diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index be1553c4bbe..5e321686531 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -144,8 +144,8 @@ tf_cuda_library( name = "gpu_runtime", srcs = glob( [ - "common_runtime/gpu/**/*.h", - "common_runtime/gpu/**/*.cc", + "common_runtime/gpu/*.h", + "common_runtime/gpu/*.cc", ], exclude = [ "**/*main.cc", @@ -628,6 +628,7 @@ filegroup( "//tensorflow/core:kernels/relu_op.h", "//tensorflow/core:kernels/softplus_op.cc", "//tensorflow/core:kernels/softplus_op.h", + "//tensorflow/core:kernels/stack_ops.cc", "//tensorflow/core:kernels/transpose_op.cc", "//tensorflow/core:kernels/transpose_op.h", "//tensorflow/core:kernels/transpose_op_functor.h", @@ -673,6 +674,7 @@ cc_library( copts = [ "-mfpu=neon", "-std=c++11", + "-O2", ], tags = [ "manual", diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index adafb7355dc..cd373af7124 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -60,6 +60,7 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name, if (op_def == nullptr) { status->Update( errors::NotFound("Op type not registered '", op_type_name, "'")); + LOG(INFO) << status->ToString(); static bool first_unregistered = true; if (first_unregistered) { OpList op_list; diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 0fd9e666701..e819d8755e2 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -817,6 +817,11 @@ class OpKernelContext { return output_allocation_types_[index]; } + // Per-step resource manager for use by white-listed internal ops. + ResourceMgr* step_resource_manager() const { + return params_.step_resource_manager; + } + private: Allocator* get_allocator(AllocatorAttributes attr) { Allocator* allocator = params_.device->GetAllocator(attr); @@ -836,13 +841,6 @@ class OpKernelContext { } } - // Per-step resource manager for use by white-listed internal ops. - friend class TemporaryVariableOp; - friend class DestroyTemporaryVariableOp; - ResourceMgr* step_resource_manager() const { - return params_.step_resource_manager; - } - // Internal common method used when allocating tensor memory Status allocate_tensor(DataType type, const TensorShape& shape, Tensor* out_tensor, AllocatorAttributes attr); diff --git a/tensorflow/core/framework/tensor_util_test.cc b/tensorflow/core/framework/tensor_util_test.cc index c485fa26d2a..0602fbea512 100644 --- a/tensorflow/core/framework/tensor_util_test.cc +++ b/tensorflow/core/framework/tensor_util_test.cc @@ -140,7 +140,7 @@ TEST(TensorUtil, Concat) { std::vector<Tensor> to_concat; int64 total_size = 0; int offset = 0; - for (int entry = 0; entry < sizes.size(); ++entry) { + for (size_t entry = 0; entry < sizes.size(); ++entry) { const int64 size = sizes[entry]; Tensor tensor(DT_INT32, TensorShape({size, 2})); for (int i = offset; i < offset + size; ++i) { @@ -175,7 +175,7 @@ TEST(TensorUtil, Split) { ASSERT_EQ(sizes.size(), splits.size()); int offset = 0; - for (int entry = 0; entry < splits.size(); ++entry) { + for (size_t entry = 0; entry < splits.size(); ++entry) { const int64 size = sizes[entry]; const Tensor& split = splits[entry]; diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index b7bb037a5fd..abb3d686761 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -1011,6 +1011,10 @@ Status Partition(const PartitionOptions& opts, Graph* g, if (!edge->IsControlEdge() && IsRefType(src->output_type(edge->src_output()))) { + AddNodeAttr("_start_time", recv_start_time, recv); + if (real_recv != recv) { + AddNodeAttr("_start_time", recv_start_time, real_recv); + } // If src is of ref type and the edge is not a control edge, dst has // read semantics and therefore we must control the recv. ref_recvs.push_back(real_recv); diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index 244b71f9dd2..a0827466e0c 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -37,7 +37,7 @@ __global__ void BiasOpCustomKernel(int nthreads, const T* input, const T* bias, T* output) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int bias_offset = index % bias_size; - output[index] = __ldg(input + index) + __ldg(bias + bias_offset); + output[index] = ldg(input + index) + ldg(bias + bias_offset); } } diff --git a/tensorflow/core/kernels/constant_op_gpu.cu.cc b/tensorflow/core/kernels/constant_op_gpu.cu.cc index d68b014fe7a..5991391850a 100644 --- a/tensorflow/core/kernels/constant_op_gpu.cu.cc +++ b/tensorflow/core/kernels/constant_op_gpu.cu.cc @@ -42,9 +42,10 @@ struct scalar_const_op { return *val; } - template <typename Index> - EIGEN_STRONG_INLINE const Packet packetOp(Index, Index = 0) const { - return internal::pset1<Packet>(*val); + template <typename Index, typename PacketType = Packet> + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType + packetOp(Index, Index = 0) const { + return internal::pset1<PacketType>(*val); } }; diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index 10a3558013c..dae06f4bfc7 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -383,12 +383,53 @@ class Conv2DCustomBackpropInputOp : public OpKernel { // The output image size is the spatial size of the output. const int output_image_size = out_rows * out_cols; + // TODO(andydavis) Get L2/L3 cache sizes from device. + const size_t l2_cache_size = 256LL << 10; + const size_t l3_cache_size = 30LL << 20; + + // Use L3 cache size as target working set size. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + // Calculate size of matrices involved in MatMul: C = A x B. + const size_t size_A = output_image_size * out_depth; + + const size_t size_B = filter_total_size * out_depth; + + const size_t size_C = output_image_size * filter_total_size; + + const size_t work_unit_size = size_A + size_B + size_C; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Calculate per-thread work unit size. + const size_t thread_work_unit_size = + work_unit_size / worker_threads.num_threads; + + // Set minimum per-thread work unit size to size of L2 cache. + const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T); + + // Use parallel tensor contractions if there is no batching, or if the + // minimum per-thread work unit size threshold has been exceeded. + // Otherwise, revert to multiple single-threaded matmul ops running in + // parallel to keep all threads busy. + // TODO(andydavis) Explore alternatives to branching the code in this way + // (i.e. run multiple, parallel tensor contractions in another thread pool). + const bool use_parallel_contraction = + batch == 1 || thread_work_unit_size >= min_thread_work_unit_size; + + const size_t shard_size = + use_parallel_contraction + ? 1 + : (target_working_set_size + work_unit_size - 1) / work_unit_size; + Tensor col_buffer; - OP_REQUIRES_OK( - context, - context->allocate_temp( - DataTypeToEnum<T>::value, - TensorShape({output_image_size, filter_total_size}), &col_buffer)); + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}), + &col_buffer)); // The input offset corresponding to a single input image. const int input_offset = input_rows * input_cols * in_depth; @@ -400,31 +441,74 @@ class Conv2DCustomBackpropInputOp : public OpKernel { auto* out_backprop_data = out_backprop.template flat<T>().data(); auto* input_backprop_data = in_backprop->template flat<T>().data(); - typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, - Eigen::Unaligned> TensorMap; - typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, - Eigen::Unaligned> ConstTensorMap; + if (use_parallel_contraction) { + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> ConstTensorMap; - // Initialize contraction dims (we need to transpose 'B' below). - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; - contract_dims[0].first = 1; - contract_dims[0].second = 1; + // Initialize contraction dims (we need to transpose 'B' below). + Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 1; + contract_dims[0].second = 1; - for (int image_id = 0; image_id < batch; ++image_id) { - // Compute gradient into col_buffer. - TensorMap C(col_buffer_data, output_image_size, filter_total_size); + for (int image_id = 0; image_id < batch; ++image_id) { + // Compute gradient into col_buffer. + TensorMap C(col_buffer_data, output_image_size, filter_total_size); - ConstTensorMap A(out_backprop_data + output_offset * image_id, - output_image_size, out_depth); - ConstTensorMap B(filter_data, filter_total_size, out_depth); + ConstTensorMap A(out_backprop_data + output_offset * image_id, + output_image_size, out_depth); + ConstTensorMap B(filter_data, filter_total_size, out_depth); - C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); + C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); - Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols, filter_rows, - filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride, - stride, input_backprop_data); + Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols, + filter_rows, filter_cols, pad_top, pad_left, pad_bottom, + pad_right, stride, stride, input_backprop_data); - input_backprop_data += input_offset; + input_backprop_data += input_offset; + } + } else { + typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, + Eigen::RowMajor>> MatrixMap; + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, + Eigen::RowMajor>> ConstMatrixMap; + + for (int image_id = 0; image_id < batch; image_id += shard_size) { + const int shard_limit = std::min(static_cast<int>(shard_size), + static_cast<int>(batch) - image_id); + + auto shard = [&in_depth, &input_rows, &input_cols, &filter_rows, + &filter_cols, &pad_top, &pad_left, &pad_bottom, + &pad_right, &stride, &output_image_size, + &filter_total_size, &out_depth, &input_backprop_data, + &col_buffer_data, &out_backprop_data, &filter_data, + &input_offset, &output_offset, + &size_C](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + T* im2col_buf = col_buffer_data + shard_id * size_C; + T* input_data = input_backprop_data + shard_id * input_offset; + const T* out_data = out_backprop_data + shard_id * output_offset; + + // Compute gradient into 'im2col_buf'. + MatrixMap C(im2col_buf, output_image_size, filter_total_size); + + ConstMatrixMap A(out_data, output_image_size, out_depth); + ConstMatrixMap B(filter_data, filter_total_size, out_depth); + + C.noalias() = A * B.transpose(); + + Col2im<T>(im2col_buf, in_depth, input_rows, input_cols, filter_rows, + filter_cols, pad_top, pad_left, pad_bottom, pad_right, + stride, stride, input_data); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + work_unit_size, shard); + + input_backprop_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } } } @@ -620,8 +704,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { &pad_left, &pad_bottom, &pad_right, &stride, &input_offset, &size_A](int64 start, int64 limit) { for (int shard_id = start; shard_id < limit; ++shard_id) { - auto input_data_shard = input_data + shard_id * input_offset; - auto col_data_shard = col_buffer_data + shard_id * size_A; + const T* input_data_shard = input_data + shard_id * input_offset; + T* col_data_shard = col_buffer_data + shard_id * size_A; // When we compute the gradient with respect to the filters, we need // to do im2col to allow gemm-type computation. diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index a40b3caefd6..2d7e149c7b5 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { // A simple array that contains data that can be passed between CPU and GPU. -template <typename T, int IndexCount> +template <typename T, int IndexCount, T DefaultValue> struct Array { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const { return data[index]; @@ -38,16 +38,57 @@ struct Array { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) { return data[index]; } - int data[IndexCount]; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() { + for (int i = 0; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) { + data[0] = a0; + for (int i = 1; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) { + data[0] = a0; + data[1] = a1; + for (int i = 2; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) { + data[0] = a0; + data[1] = a1; + data[2] = a2; + for (int i = 3; i < IndexCount; i++) { + data[i] = DefaultValue; + } + } + T data[IndexCount]; }; // A dimension type with compile-time known size. template <int IndexCount> -struct Dimension : Array<int, IndexCount> {}; +struct Dimension : Array<int, IndexCount, 1> { + typedef Array<int, IndexCount, 1> Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1) + : Base(a0, a1) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2) + : Base(a0, a1, a2) {} +}; // An index type with compile-time known size. template <int IndexCount> -struct Index : Array<int, IndexCount> {}; +struct Index : Array<int, IndexCount, 0> { + typedef Array<int, IndexCount, 0> Base; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {} + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2) + : Base(a0, a1, a2) {} +}; // A helper function that converts a tensor index into a flat array index. template <int IndexCount> @@ -94,7 +135,7 @@ __global__ void SwapDimension0And2InTensor3(int nthreads, const T* input, int input_index = TensorIndexToFlat(input_tensor_index, input_dims); - output[output_index] = __ldg(input + input_index); + output[output_index] = ldg(input + input_index); } } @@ -119,7 +160,88 @@ __global__ void SwapDimension1And2InTensor3(int nthreads, const T* input, int input_index = TensorIndexToFlat(input_tensor_index, input_dims); - output[output_index] = __ldg(input + input_index); + output[output_index] = ldg(input + input_index); + } +} + +// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor, +// where dimensions are zero-based: output[i][j][k] = input[i][k][j]. +// TileSize could be arbitrary. But for best performance, it is better to be +// the same as number of threads in a warp, which is 32 for almost all GPU +// architectures. +template <typename T, int TileSize> +__global__ void SwapDimension1And2InTensor3UsingTiles(const T* input, + Dimension<3> input_dims, + T* output) { + // One extra line in the inner dimension to avoid share memory bank conflict. + __shared__ T shared_memory_tile[TileSize][TileSize + 1]; + + int x = threadIdx.x; + if (x >= TileSize) { + return; + } + + Dimension<3> output_dims = { + input_dims[0], input_dims[2], input_dims[1], + }; + + Dimension<3> input_dims_in_tiles = { + input_dims[0], (input_dims[1] + TileSize - 1) / TileSize, + (input_dims[2] + TileSize - 1) / TileSize, + }; + + Index<3> input_tile_index = + FlatToTensorIndex(blockIdx.x, input_dims_in_tiles); + + Index<3> input_tile_origin = { + input_tile_index[0], input_tile_index[1] * TileSize, + input_tile_index[2] * TileSize, + }; + + int input_origin_flat_index = + TensorIndexToFlat(input_tile_origin, input_dims); + + int tile_width = TileSize; + // Only the last row or column may not have the full size. + if (input_tile_index[2] == input_dims_in_tiles[2] - 1) { + tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSize; + } + int tile_height = TileSize; + if (input_tile_index[1] == input_dims_in_tiles[1] - 1) { + tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize; + } + + // Load the data from input memory to the shared memory tile. + if (x < tile_width) { + int input_flat_index = input_origin_flat_index + x; + for (int y = 0; y < tile_height; y++) { + shared_memory_tile[y][x] = input[input_flat_index]; + input_flat_index += input_dims[2]; + } + } + + __syncthreads(); + + Index<3> output_tile_index = { + input_tile_index[0], input_tile_index[2], input_tile_index[1], + }; + + Index<3> output_tile_origin = { + output_tile_index[0], output_tile_index[1] * TileSize, + output_tile_index[2] * TileSize, + }; + + int output_origin_flat_index = + TensorIndexToFlat(output_tile_origin, output_dims); + + int output_flat_index = output_origin_flat_index + x; + + // Load the data from the shared memory tile to the output memory. + if (x < tile_height) { + for (int y = 0; y < tile_width; y++) { + output[output_flat_index] = shared_memory_tile[x][y]; + output_flat_index += output_dims[2]; + } } } @@ -212,6 +334,39 @@ struct PadInput<GPUDevice, T, int> { } }; +// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a +// 3D tensor. It looks at the shape of the incoming data, and decides the best +// strategy to launch. +template <typename T> +void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, + const Dimension<3>& input_dims, T* output) { + // If both dimensions are not trivial, use tiles for the actual swapping. + // Otherwise, the trivial swapping relying on the ldg cache is more efficient. + static const int kMinDimensionToUseTiles = 16; + bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles && + input_dims[2] >= kMinDimensionToUseTiles); + if (use_tiles) { + // The tile-size can be chosen to be arbitrary number. But it is better to + // be the same as number of threads in a warp, which is 32. + static const int TileSize = 32; + Dimension<3> input_dims_in_tiles = { + input_dims[0], (input_dims[1] + TileSize - 1) / TileSize, + (input_dims[2] + TileSize - 1) / TileSize, + }; + int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] * + input_dims_in_tiles[2]; + SwapDimension1And2InTensor3UsingTiles< + T, TileSize><<<total_tiles_count, TileSize, 0, d.stream()>>>( + input, input_dims, output); + } else { + int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; + CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d); + SwapDimension1And2InTensor3< + T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( + config.virtual_thread_count, input, input_dims, output); + } +} + // A GPU helper functor that converts NHWC TensorFlow data format to // NCHW format that is accepted by Cudnn. template <typename T> @@ -223,10 +378,7 @@ struct NHWCToNCHW<GPUDevice, T> { combined_dims[0] = in.dimension(0); combined_dims[1] = in.dimension(1) * in.dimension(2); combined_dims[2] = in.dimension(3); - CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension1And2InTensor3< - T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( - config.virtual_thread_count, in.data(), combined_dims, out.data()); + RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); } }; @@ -241,10 +393,7 @@ struct NCHWToNHWC<GPUDevice, T> { combined_dims[0] = in.dimension(0); combined_dims[1] = in.dimension(1); combined_dims[2] = in.dimension(2) * in.dimension(3); - CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension1And2InTensor3< - T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>( - config.virtual_thread_count, in.data(), combined_dims, out.data()); + RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data()); } }; diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc index ddf68f8ff49..fb779f24665 100644 --- a/tensorflow/core/kernels/lrn_op.cc +++ b/tensorflow/core/kernels/lrn_op.cc @@ -66,7 +66,7 @@ class LRNOp : public OpKernel { context->allocate_output( 0, TensorShape({batch, rows, cols, depth}), &output)); -#if !defined(__ANDROID__) +#if defined(__ANDROID__) MognetLRN(in, batch, rows, cols, depth, output); #else const int nodes = cols * rows; diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc index 17ce90b8126..c4eea44044c 100644 --- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc @@ -49,13 +49,16 @@ class ResizeNearestNeighborOp : public OpKernel { errors::InvalidArgument("shape_t must have two elements", shape_t.shape().ShortDebugString())); - auto Svec = shape_t.vec<int32>(); + auto sizes = shape_t.vec<int32>(); + OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0, + errors::InvalidArgument("shape_t's elements must be positive")); + // Initialize shape to the batch size of the input, then add // the rest of the dimensions Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({input.dim_size(0), Svec(0), - Svec(1), input.dim_size(3)}), + 0, TensorShape({input.dim_size(0), sizes(0), + sizes(1), input.dim_size(3)}), &output)); const int64 batch_size = input.dim_size(0); @@ -87,12 +90,93 @@ class ResizeNearestNeighborOp : public OpKernel { } }; -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T") \ - .HostMemory("size"), \ - ResizeNearestNeighborOp<CPUDevice, T>); +template <typename Device, typename T> +class ResizeNearestNeighborOpGrad : public OpKernel { + public: + explicit ResizeNearestNeighborOpGrad(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // Grab and validate the input: + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().ShortDebugString())); + + // Grab and validate the output shape: + const Tensor& shape_t = context->input(1); + OP_REQUIRES(context, shape_t.dims() == 1, + errors::InvalidArgument("shape_t must be 1-dimensional", + shape_t.shape().ShortDebugString())); + OP_REQUIRES(context, shape_t.NumElements() == 2, + errors::InvalidArgument("shape_t must have two elements", + shape_t.shape().ShortDebugString())); + + auto sizes = shape_t.vec<int32>(); + OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0, + errors::InvalidArgument("shape_t's elements must be positive")); + + // Initialize shape to the batch size of the input, then add + // the rest of the dimensions + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({input.dim_size(0), sizes(0), + sizes(1), input.dim_size(3)}), + &output)); + + const int64 batch_size = input.dim_size(0); + const int64 in_height = input.dim_size(1); + const int64 in_width = input.dim_size(2); + const int64 channels = input.dim_size(3); + + const int64 out_height = output->dim_size(1); + const int64 out_width = output->dim_size(2); + + typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>(); + typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>(); + + const float height_scale = out_height / static_cast<float>(in_height); + const float width_scale = out_width / static_cast<float>(in_width); + + for (int c = 0; c < channels; ++c) { + for (int y = 0; y < out_height; ++y) { + for (int x = 0; x < out_width; ++x) { + for (int b = 0; b < batch_size; ++b) { + output_data(b, y, x, c) = 0; + } + } + } + } + + for (int c = 0; c < channels; ++c) { + for (int y = 0; y < in_height; ++y) { + const int out_y = std::min(static_cast<int64>(floorf(y * height_scale)), + (out_height - 1)); + + for (int x = 0; x < in_width; ++x) { + const int out_x = std::min( + static_cast<int64>(floorf(x * width_scale)), (out_width - 1)); + + for (int b = 0; b < batch_size; ++b) { + output_data(b, out_y, out_x, c) += input_data(b, y, x, c); + } + } + } + } + } +}; + +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("size"), \ + ResizeNearestNeighborOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("size"), \ + ResizeNearestNeighborOpGrad<CPUDevice, T>); REGISTER_KERNEL(uint8); REGISTER_KERNEL(int8); diff --git a/tensorflow/core/kernels/sparse_matmul_op.cc b/tensorflow/core/kernels/sparse_matmul_op.cc index 36f262032d6..c0aecac1adf 100644 --- a/tensorflow/core/kernels/sparse_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_matmul_op.cc @@ -141,8 +141,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) { Index idx; int data3_size = 0; for (int i = 0; i < num_blocks; ++i) { - int num_block_cols = - std::min(block_size, num_cols - block_size * (num_blocks - 1)); + int num_block_cols = std::min(block_size, num_cols - block_size * i); for (int row = 0; row < num_rows; ++row) { idx3.m = static_cast<uint8>(row); const float* start = @@ -457,13 +456,11 @@ class SparseMatMulOp : public OpKernel { errors::InvalidArgument("b is not a matrix")); auto left = a.matrix<float>(); - auto right_mat = b.matrix<float>(); + auto right = b.matrix<float>(); const int m = transpose_a_ ? left.dimension(1) : left.dimension(0); const int k = transpose_a_ ? left.dimension(0) : left.dimension(1); - const int n = - transpose_b_ ? right_mat.dimension(0) : right_mat.dimension(1); - const int k2 = - transpose_b_ ? right_mat.dimension(1) : right_mat.dimension(0); + const int n = transpose_b_ ? right.dimension(0) : right.dimension(1); + const int k2 = transpose_b_ ? right.dimension(1) : right.dimension(0); OP_REQUIRES(ctx, k == k2, errors::InvalidArgument("Matrix size incompatible: a: ", @@ -473,7 +470,7 @@ class SparseMatMulOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output)); auto out = output->matrix<float>(); - if (!a_is_sparse_) { + if (!a_is_sparse_ && !b_is_sparse_) { // Fallback to Eigen contract. // Note that we currently don't optimize the case where only right is // sparse. That can generally be handled by tranposing the order of the @@ -482,26 +479,41 @@ class SparseMatMulOp : public OpKernel { dim_pair[0].first = transpose_a_ ? 0 : 1; dim_pair[0].second = transpose_b_ ? 1 : 0; out.device(ctx->template eigen_device<CPUDevice>()) = - left.contract(right_mat, dim_pair); + left.contract(right, dim_pair); return; } + auto left_mat = &left; + auto right_mat = &right; + bool transpose_output = false; + bool transpose_a = transpose_a_; + bool transpose_b = transpose_b_; + if (!a_is_sparse_) { + // Swap the order of multiplications using the identity: + // A * B = (B' * A')'. + std::swap(left_mat, right_mat); + std::swap(transpose_a, transpose_b); + transpose_a = !transpose_a; + transpose_b = !transpose_b; + transpose_output = !transpose_output; + } std::unique_ptr<Matrix> right_tr_mat; std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map; - if (transpose_b_) { + if (transpose_b) { // TODO(agarwal): avoid transposing the matrix here and directly handle // transpose in CreateDenseSlices. - right_tr_mat.reset(new Matrix(k, n)); + right_tr_mat.reset( + new Matrix(right_mat->dimension(1), right_mat->dimension(0))); Eigen::array<int, 2> perm({1, 0}); right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) = - right_mat.shuffle(perm); + right_mat->shuffle(perm); right_tr_map.reset(new TTypes<float>::ConstMatrix( right_tr_mat->data(), right_tr_mat->dimensions())); + right_mat = right_tr_map.get(); } - TTypes<float>::ConstMatrix& right = - transpose_b_ ? *right_tr_map : right_mat; - SparseMatMul(left, right, transpose_a_, - ctx->device()->tensorflow_cpu_worker_threads(), &out); + SparseMatMul(*left_mat, *right_mat, transpose_a, + ctx->device()->tensorflow_cpu_worker_threads(), + transpose_output, &out); } private: @@ -510,7 +522,7 @@ class SparseMatMulOp : public OpKernel { static inline void SparseMatMul( const ConstMatrixMap& left, const ConstMatrixMap& right, bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, - MatrixMap* output); + bool transpose_output, MatrixMap* output); // Computes multiplication of left and num_cols columns of right, and stores // the output block in *"output" at offsets "output_row_offset" and @@ -518,9 +530,9 @@ class SparseMatMulOp : public OpKernel { // else adds the values to the existing values. static inline void ComputeOutputBlock(const std::vector<SparseSlice*>& left, const ConstMatrixMap& right, - const int num_cols, - int output_row_offset, + int num_cols, int output_row_offset, int output_col_offset, bool assign, + bool transpose_output, MatrixMap* output); // Encodes "mat" using a sparse representation and stores that in @@ -578,9 +590,10 @@ class SparseMatMulOp : public OpKernel { inline void SparseMatMulOp::ComputeOutputBlock( const std::vector<SparseSlice*>& left, const ConstMatrixMap& right, - const int num_cols, int output_row_offset, int output_col_offset, - bool assign, MatrixMap* output) { - const int num_rows = left[0]->num_rows; + int num_cols, int output_row_offset, int output_col_offset, bool assign, + bool transpose_output, MatrixMap* output) { + static const Eigen::array<int, 2> perm({1, 0}); + int num_rows = left[0]->num_rows; const int rhs_num_cols = right.dimension(1); DCHECK_LE(num_cols, rhs_num_cols); Matrix out(num_rows, rhs_num_cols); @@ -593,18 +606,33 @@ inline void SparseMatMulOp::ComputeOutputBlock( if (!assign) { const Eigen::array<int, 2> begin = {output_row_offset, output_col_offset}; const Eigen::array<int, 2> sizes = {num_rows, num_cols}; - if (num_cols == rhs_num_cols) { - output->slice(begin, sizes) += out; + if (transpose_output) { + if (num_cols == rhs_num_cols) { + output->shuffle(perm).slice(begin, sizes) += out; + } else { + static const Eigen::array<int, 2> zero = {0, 0}; + output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes); + } } else { - static const Eigen::array<int, 2> zero = {0, 0}; - output->slice(begin, sizes) += out.slice(zero, sizes); + if (num_cols == rhs_num_cols) { + output->slice(begin, sizes) += out; + } else { + static const Eigen::array<int, 2> zero = {0, 0}; + output->slice(begin, sizes) += out.slice(zero, sizes); + } } } else { - // output->slice(begin, sizes) = out.slice(zero, sizes), implemented - // using memcpy. + std::unique_ptr<Matrix> out_tr; + if (transpose_output) { + out_tr.reset(new Matrix(rhs_num_cols, num_rows)); + *out_tr = out.shuffle(perm); + std::swap(output_row_offset, output_col_offset); + std::swap(num_rows, num_cols); + } + const Matrix& final_out = transpose_output ? *out_tr : out; for (int i = 0; i < num_rows; ++i) { - memcpy(&(*output)(output_row_offset + i, output_col_offset), &out(i, 0), - num_cols * sizeof(float)); + memcpy(&(*output)(output_row_offset + i, output_col_offset), + &final_out(i, 0), num_cols * sizeof(float)); } } } @@ -807,7 +835,7 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left, inline void SparseMatMulOp::SparseMatMul( const ConstMatrixMap& left, const ConstMatrixMap& right, bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool, - MatrixMap* output) { + bool transpose_output, MatrixMap* output) { const int num_threads = thread_pool->num_threads; int KR, NR, KL, JB, IB; ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL, @@ -868,7 +896,7 @@ inline void SparseMatMulOp::SparseMatMul( tasks.push_back(std::bind( &SparseMatMulOp::ComputeOutputBlock, block_left_slices, std::ref(*right_slices[j_inner]), num_cols, M * i_inner, - N * j_inner + nb * NR, kb == 0, output)); + N * j_inner + nb * NR, kb == 0, transpose_output, output)); } } } diff --git a/tensorflow/core/kernels/sparse_matmul_op_test.cc b/tensorflow/core/kernels/sparse_matmul_op_test.cc index 1146e2e8756..52e9269c8c9 100644 --- a/tensorflow/core/kernels/sparse_matmul_op_test.cc +++ b/tensorflow/core/kernels/sparse_matmul_op_test.cc @@ -29,8 +29,11 @@ random::SimplePhilox rnd(&philox); void Sparsify(Tensor* t, float sparsity) { const int64 N = t->NumElements(); CHECK_LE(sparsity, 1); - if (sparsity <= 0) return; auto flat = t->flat<float>(); + if (sparsity == 1) { + flat.setZero(); + return; + } static const uint32 K = 10000; for (int64 i = 0; i < N; ++i) { if (rnd.Uniform(K) < sparsity * K) { @@ -55,25 +58,21 @@ Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a, return ret; } -static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity, - bool transpose_a, bool transpose_b, - bool a_sparse, bool b_sparse) { - a_sparse = a_sparse && (sparsity > 0); - b_sparse = b_sparse && (sparsity > 0); +static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, + float sparsity_a, float sparsity_b, + bool transpose_a, bool transpose_b) { + bool a_sparse = (sparsity_a > 0); + bool b_sparse = (sparsity_b > 0); auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d}); Tensor left(DataTypeToEnum<float>::value, left_shape); left.flat<float>().setRandom(); - if (a_sparse) { - Sparsify(&left, sparsity); - } + Sparsify(&left, sparsity_a); auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n}); Tensor right(DataTypeToEnum<float>::value, right_shape); right.flat<float>().setRandom(); - if (b_sparse) { - Sparsify(&right, sparsity); - } + Sparsify(&right, sparsity_b); SparseMatMulNode(g, test::graph::Constant(g, left), test::graph::Constant(g, right), transpose_a, transpose_b, @@ -81,59 +80,62 @@ static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity, return g; } -static Graph* SparseMatMul(int m, int n, int d, float sparsity, - bool transpose_a, bool transpose_b) { +static Graph* SparseMatMul(int m, int n, int d, float sparsity_a, + float sparsity_b, bool transpose_a, + bool transpose_b) { Graph* g = new Graph(OpRegistry::Global()); - return SparseMatMulHelper(g, m, n, d, sparsity, transpose_a, transpose_b, - true, false); + return SparseMatMulHelper(g, m, n, d, sparsity_a, sparsity_b, transpose_a, + transpose_b); } -static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_a, - float sparsity_b) { +#define BM_SPARSE(M, K, N, S1, S2, TA, TB) \ + static void BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB( \ + int iters) { \ + testing::StopTiming(); \ + testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \ + std::string label = \ + strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TA, TB, \ + S1 / 100.0, S2 / 100.0); \ + testing::SetLabel(label); \ + testing::UseRealTime(); \ + auto g = SparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, TA, TB); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ + BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB); + +BM_SPARSE(2048, 2048, 2048, 0, 0, false, false); +BM_SPARSE(2048, 2048, 2048, 1, 0, false, false); +BM_SPARSE(2048, 2048, 2048, 50, 0, false, false); +BM_SPARSE(2048, 2048, 2048, 85, 0, false, false); +BM_SPARSE(2048, 2048, 2048, 99, 0, false, false); + +BM_SPARSE(2048, 2048, 2048, 0, 50, false, false); +BM_SPARSE(2048, 2048, 2048, 0, 85, false, false); + +BM_SPARSE(2048, 2048, 2048, 85, 0, true, false); +BM_SPARSE(2048, 2048, 2048, 85, 0, false, true); +BM_SPARSE(2048, 2048, 2048, 85, 0, true, true); + +BM_SPARSE(2048, 2048, 2048, 0, 85, true, false); +BM_SPARSE(2048, 2048, 2048, 0, 85, false, true); +BM_SPARSE(2048, 2048, 2048, 0, 85, true, true); + +BM_SPARSE(1024, 1024, 1024, 0, 0, false, false); +BM_SPARSE(1024, 1024, 1024, 1, 0, false, false); +BM_SPARSE(1024, 1024, 1024, 85, 0, false, false); + +BM_SPARSE(256, 256, 256, 1, 0, false, false); +BM_SPARSE(512, 512, 512, 1, 0, false, false); + +static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1, + float sparsity_2) { Graph* g = new Graph(OpRegistry::Global()); - if (sparsity_a == 0 && sparsity_b > 0) { - SparseMatMulHelper(g, m, n, d, sparsity_a, false, false, false, false); - SparseMatMulHelper(g, n, d, m, sparsity_b, true, true, true, false); - SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false); - } else { - SparseMatMulHelper(g, m, n, d, sparsity_a, false, true, true, false); - SparseMatMulHelper(g, d, n, m, sparsity_a, true, false, true, true); - SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false); - } + SparseMatMulHelper(g, d, n, m, sparsity_1, sparsity_2, true, false); + SparseMatMulHelper(g, m, d, n, sparsity_2, 0, false, true); return g; } -#define BM_SPARSE(M, K, N, S) \ - static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \ - testing::StopTiming(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \ - std::string label; \ - if (S == 0) { \ - label = strings::Printf("%d*%d*%d_Eigen", M, K, N); \ - } else { \ - label = strings::Printf("%d*%d*%d_sparsity:%0.2f", M, K, N, S / 100.0); \ - } \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ - auto g = SparseMatMul(M, N, K, S / 100.0, false, false); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ - } \ - BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S); - -BM_SPARSE(2048, 2048, 2048, 0); -BM_SPARSE(2048, 2048, 2048, 1); -BM_SPARSE(2048, 2048, 2048, 50); -BM_SPARSE(2048, 2048, 2048, 85); -BM_SPARSE(2048, 2048, 2048, 99); - -BM_SPARSE(1024, 1024, 1024, 0); -BM_SPARSE(1024, 1024, 1024, 1); -BM_SPARSE(1024, 1024, 1024, 85); - -BM_SPARSE(256, 256, 256, 1); -BM_SPARSE(512, 512, 512, 1); - #define BM_SPARSE_MULTI(M, K, N, S1, S2) \ static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \ testing::StopTiming(); \ @@ -151,22 +153,4 @@ BM_SPARSE(512, 512, 512, 1); BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82); BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83); -#define BM_SPARSE_TR(M, K, N, S, TA, TB) \ - static void BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB(int iters) { \ - testing::StopTiming(); \ - testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \ - std::string label = \ - strings::Printf("%d_%d_%d_%d_%d_%0.2f", M, K, N, TA, TB, S / 100.0); \ - testing::SetLabel(label); \ - testing::UseRealTime(); \ - auto g = SparseMatMul(M, N, K, S / 100.0, TA, TB); \ - testing::StartTiming(); \ - test::Benchmark("cpu", g).Run(iters); \ - } \ - BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB); - -BM_SPARSE_TR(2048, 2048, 2048, 1, true, false); -BM_SPARSE_TR(2048, 2048, 2048, 1, false, true); -BM_SPARSE_TR(2048, 2048, 2048, 1, true, true); - } // end namespace tensorflow diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc new file mode 100644 index 00000000000..055050cd34a --- /dev/null +++ b/tensorflow/core/kernels/stack_ops.cc @@ -0,0 +1,181 @@ +// See docs in ../ops/data_flow_ops.cc. + +#include <limits.h> +#include <vector> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/tensor.h" +#include "tensorflow/core/public/tensor_shape.h" + +namespace tensorflow { + +class Stack : public ResourceBase { + public: + Stack(const DataType& elem_type, const Tensor& handle) + : elem_type_(elem_type), handle_(handle) {} + + void Push(const PersistentTensor& value) { + mutex_lock l(mu_); + stack_.push_back(value); + } + + bool Pop(PersistentTensor* value) { + mutex_lock l(mu_); + if (!stack_.empty()) { + *value = stack_.back(); + stack_.pop_back(); + return true; + } + return false; + } + + DataType ElemType() { return elem_type_; } + + string DebugString() override { + mutex_lock l(mu_); + return strings::StrCat("#elem:", stack_.size()); + } + + private: + friend class StackOp; + mutex* mu() { return &mu_; } + Tensor* handle() { return &handle_; } + + mutex mu_; + DataType elem_type_; + Tensor handle_; + std::vector<PersistentTensor> stack_ GUARDED_BY(mu_); +}; + +// A per-run local stack. The stack uses a "per-step" resource manager which +// ensures that correct garbage collection on error or successful completion. +class StackOp : public OpKernel { + public: + explicit StackOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("elem_type", &elem_type_)); + OP_REQUIRES_OK(context, context->GetAttr("stack_name", &stack_name_)); + if (stack_name_ == "") stack_name_ = name(); + } + + void Compute(OpKernelContext* ctx) override { + // Create the stack handle. + Tensor stack_handle; + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING, + tensorflow::TensorShape({2}), + &stack_handle, alloc_attr)); + auto handle = stack_handle.flat<string>(); + handle(0) = "_stacks"; + handle(1) = stack_name_; + // Store the handle in a container of the per-step RM. + ResourceMgr* rm = ctx->step_resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, + errors::Internal("No per-step resource manager.")); + Stack* stack = new Stack(elem_type_, stack_handle); + OP_REQUIRES_OK(ctx, rm->Create(handle(0), stack_name_, stack)); + ctx->set_output_ref(0, stack->mu(), stack->handle()); + } + + private: + DataType elem_type_; + string stack_name_; + + TF_DISALLOW_COPY_AND_ASSIGN(StackOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp); +REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"), + StackOp); + +class StackPushOp : public OpKernel { + public: + explicit StackPushOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + Tensor Tstack_handle = ctx->mutable_input(0, false); + OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2, + errors::InvalidArgument( + "Stack handle must have two elements, but had shape: ", + Tstack_handle.shape().DebugString())); + const string& container = Tstack_handle.flat<string>()(0); + const string& stack_name = Tstack_handle.flat<string>()(1); + ResourceMgr* rm = ctx->step_resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, + errors::Internal("No per-step resource manager.")); + Stack* stack = nullptr; + OP_REQUIRES_OK(ctx, rm->Lookup(container, stack_name, &stack)); + OP_REQUIRES(ctx, ctx->input_dtype(1) == stack->ElemType(), + errors::InvalidArgument("Must have type ", stack->ElemType(), + " but got ", ctx->input_dtype(1))); + stack->Push(PersistentTensor(ctx->input(1))); + ctx->set_output(0, ctx->input(1)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU), StackPushOp); +REGISTER_KERNEL_BUILDER( + Name("StackPush").Device(DEVICE_GPU).HostMemory("handle"), StackPushOp); + +class StackPopOp : public OpKernel { + public: + explicit StackPopOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + Tensor Tstack_handle = ctx->mutable_input(0, false); + OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2, + errors::InvalidArgument( + "Stack handle must have two elements, but had shape: ", + Tstack_handle.shape().DebugString())); + const string& container = Tstack_handle.flat<string>()(0); + const string& stack_name = Tstack_handle.flat<string>()(1); + ResourceMgr* rm = ctx->step_resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, + errors::Internal("No per-step resource manager.")); + Stack* stack = nullptr; + OP_REQUIRES_OK(ctx, rm->Lookup(container, stack_name, &stack)); + PersistentTensor value; + bool has_value = stack->Pop(&value); + if (!has_value) { + errors::InvalidArgument("Calling Pop() when the stack is empty."); + } + ctx->set_output(0, *value.AccessTensor(ctx)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("StackPop").Device(DEVICE_CPU), StackPopOp); +REGISTER_KERNEL_BUILDER( + Name("StackPop").Device(DEVICE_GPU).HostMemory("handle"), StackPopOp); + +class StackCloseOp : public OpKernel { + public: + explicit StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + Tensor Tstack_handle = ctx->mutable_input(0, false); + OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2, + errors::InvalidArgument( + "Stack handle must have two elements, but had shape: ", + Tstack_handle.shape().DebugString())); + const string& container = Tstack_handle.flat<string>()(0); + const string& stack_name = Tstack_handle.flat<string>()(1); + ResourceMgr* rm = ctx->step_resource_manager(); + OP_REQUIRES(ctx, rm != nullptr, + errors::Internal("No per-step resource manager.")); + OP_REQUIRES_OK(ctx, rm->Delete<Stack>(container, stack_name)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp); +REGISTER_KERNEL_BUILDER( + Name("StackClose").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_to_hash_bucket_op.cc b/tensorflow/core/kernels/string_to_hash_bucket_op.cc index 088d6a0e5f0..2bb9626381a 100644 --- a/tensorflow/core/kernels/string_to_hash_bucket_op.cc +++ b/tensorflow/core/kernels/string_to_hash_bucket_op.cc @@ -40,7 +40,8 @@ class StringToHashBucketOp : public OpKernel { &output_tensor)); auto output_flat = output_tensor->flat<int64>(); - for (size_t i = 0; i < input_flat.size(); ++i) { + typedef decltype(input_flat.size()) Index; + for (Index i = 0; i < input_flat.size(); ++i) { const uint64 input_hash = Hash64(input_flat(i)); const uint64 bucket_id = input_hash % num_buckets_; // The number of buckets is always in the positive range of int64 so is diff --git a/tensorflow/core/lib/io/block.cc b/tensorflow/core/lib/io/block.cc index 69c62a93472..8eb4a882b60 100644 --- a/tensorflow/core/lib/io/block.cc +++ b/tensorflow/core/lib/io/block.cc @@ -1,7 +1,18 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. -// +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Decodes the blocks generated by block_builder.cc. #include "tensorflow/core/lib/io/block.h" diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h index bf53245b8d1..01917fac5a5 100644 --- a/tensorflow/core/lib/io/block.h +++ b/tensorflow/core/lib/io/block.h @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #ifndef TENSORFLOW_LIB_IO_BLOCK_H_ #define TENSORFLOW_LIB_IO_BLOCK_H_ diff --git a/tensorflow/core/lib/io/block_builder.cc b/tensorflow/core/lib/io/block_builder.cc index 0c9671dcb4b..7c6c4190fb4 100644 --- a/tensorflow/core/lib/io/block_builder.cc +++ b/tensorflow/core/lib/io/block_builder.cc @@ -1,7 +1,18 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. -// +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // BlockBuilder generates blocks where keys are prefix-compressed: // // When we store a key, we drop the prefix shared with the previous diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h index e07a6478051..3ca67db7ed9 100644 --- a/tensorflow/core/lib/io/block_builder.h +++ b/tensorflow/core/lib/io/block_builder.h @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ #define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_ diff --git a/tensorflow/core/lib/io/format.cc b/tensorflow/core/lib/io/format.cc index fc43d8aeead..ccce24d8007 100644 --- a/tensorflow/core/lib/io/format.cc +++ b/tensorflow/core/lib/io/format.cc @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #include "tensorflow/core/lib/io/format.h" diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h index b9eeef65ab2..16f3c59fdae 100644 --- a/tensorflow/core/lib/io/format.h +++ b/tensorflow/core/lib/io/format.h @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #ifndef TENSORFLOW_LIB_IO_FORMAT_H_ #define TENSORFLOW_LIB_IO_FORMAT_H_ diff --git a/tensorflow/core/lib/io/iterator.cc b/tensorflow/core/lib/io/iterator.cc index 878e93a9114..02e0b879418 100644 --- a/tensorflow/core/lib/io/iterator.cc +++ b/tensorflow/core/lib/io/iterator.cc @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #include "tensorflow/core/lib/io/iterator.h" diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h index 603a2f95fe6..5b8d46909c0 100644 --- a/tensorflow/core/lib/io/iterator.h +++ b/tensorflow/core/lib/io/iterator.h @@ -1,7 +1,18 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. -// +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // An iterator yields a sequence of key/value pairs from a source. // The following class defines the interface. Multiple implementations // are provided by this library. In particular, iterators are provided diff --git a/tensorflow/core/lib/io/table.cc b/tensorflow/core/lib/io/table.cc index 8bdb80e4d2d..4b7aa0c03e5 100644 --- a/tensorflow/core/lib/io/table.cc +++ b/tensorflow/core/lib/io/table.cc @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #include "tensorflow/core/lib/io/table.h" diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h index 6f3d52ad41b..27ef0351dc8 100644 --- a/tensorflow/core/lib/io/table.h +++ b/tensorflow/core/lib/io/table.h @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #ifndef TENSORFLOW_LIB_IO_TABLE_H_ #define TENSORFLOW_LIB_IO_TABLE_H_ diff --git a/tensorflow/core/lib/io/table_builder.cc b/tensorflow/core/lib/io/table_builder.cc index d19bc64ac06..f2b3e49b875 100644 --- a/tensorflow/core/lib/io/table_builder.cc +++ b/tensorflow/core/lib/io/table_builder.cc @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #include "tensorflow/core/lib/io/table_builder.h" diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h index cebf4d8e0c0..de940167a16 100644 --- a/tensorflow/core/lib/io/table_builder.h +++ b/tensorflow/core/lib/io/table_builder.h @@ -1,7 +1,18 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. -// +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // TableBuilder provides the interface used to build a Table // (an immutable and sorted map from keys to values). // diff --git a/tensorflow/core/lib/io/table_format.txt b/tensorflow/core/lib/io/table_format.txt index 7edb9fb1213..e37c627f5b1 100644 --- a/tensorflow/core/lib/io/table_format.txt +++ b/tensorflow/core/lib/io/table_format.txt @@ -1,8 +1,8 @@ File format =========== -The table format is heavily based on the table format for the LevelDB +The table format is similar to the table format for the LevelDB open source key/value store, with the exception that our tables do not support "filter" meta blocks (Bloom Filters). See: -https://code.google.com/p/leveldb/source/browse/doc/table_format.txt +https://github.com/google/leveldb/blob/master/doc/table_format.txt diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc index 3c5a9ecbc1e..b53bdf5607f 100644 --- a/tensorflow/core/lib/io/table_test.cc +++ b/tensorflow/core/lib/io/table_test.cc @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #include "tensorflow/core/lib/io/table.h" diff --git a/tensorflow/core/lib/io/two_level_iterator.cc b/tensorflow/core/lib/io/two_level_iterator.cc index b71d1077074..640bcddbe15 100644 --- a/tensorflow/core/lib/io/two_level_iterator.cc +++ b/tensorflow/core/lib/io/two_level_iterator.cc @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #include "tensorflow/core/lib/io/two_level_iterator.h" diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h index 1cc5d2f9214..4b89c63404f 100644 --- a/tensorflow/core/lib/io/two_level_iterator.h +++ b/tensorflow/core/lib/io/two_level_iterator.h @@ -1,6 +1,17 @@ -// Copyright (c) 2011 The LevelDB Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. See the AUTHORS file for names of contributors. +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ #ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ #define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_ diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index edefb1be3b2..59a78762899 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -301,6 +301,55 @@ size: The number of elements in the given queue. // -------------------------------------------------------------------------- +REGISTER_OP("Stack") + .Output("handle: Ref(string)") + .Attr("elem_type: type") + .Attr("stack_name: string = ''") + .SetIsStateful() + .Doc(R"doc( +A stack that produces elements in first-in last-out order. + +handle: The handle to the stack. +elem_type: The type of the elements on the stack. +stack_name: Overrides the name used for the temporary stack resource. Default +value is the name of the 'Stack' op (which is guaranteed unique). +)doc"); + +REGISTER_OP("StackPush") + .Input("handle: Ref(string)") + .Input("elem: T") + .Output("output: T") + .Attr("T: type") + .Doc(R"doc( +Push an element onto the stack. + +handle: The handle to a stack. +elem: The tensor to be pushed onto the stack. +output: The same tensor as the input 'elem'. +)doc"); + +REGISTER_OP("StackPop") + .Input("handle: Ref(string)") + .Output("elem: elem_type") + .Attr("elem_type: type") + .Doc(R"doc( +Pop the element at the top of the stack. + +handle: The handle to a stack. +elem_type: The type of the elem that is popped. +elem: The tensor that is popped from the top of the stack. +)doc"); + +REGISTER_OP("StackClose") + .Input("handle: Ref(string)") + .Doc(R"doc( +Delete the stack from its resource container. + +handle: The handle to a stack. +)doc"); + +// -------------------------------------------------------------------------- + REGISTER_OP("LookupTableFind") .Input("table_handle: Ref(string)") .Input("keys: Tin") diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 13b26e26b12..e57860bd570 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -89,6 +89,22 @@ resized_images: 4-D with shape `[batch, new_height, new_width, channels]`. )doc"); +// -------------------------------------------------------------------------- +REGISTER_OP("ResizeNearestNeighborGrad") + .Input("grads: T") + .Input("size: int32") + .Output("output: T") + .Attr("T: {uint8, int8, int32, float, double}") + .Doc(R"doc( +Computes the gradient of nearest neighbor interpolation. + +grads: 4-D with shape `[batch, height, width, channels]`. +size:= A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The + original input size. +output: 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients + with respect to the input image. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("RandomCrop") .Input("image: T") diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 44e62cc4316..023c598aa61 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -655,7 +655,7 @@ that `segment_ids[j] == i`. segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s first dimension. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -684,7 +684,7 @@ values summed. segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s first dimension. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -712,7 +712,7 @@ that `segment_ids[j] == i`. segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s first dimension. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -740,7 +740,7 @@ that `segment_ids[j] == i`. segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s first dimension. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -767,7 +767,7 @@ that `segment_ids[j] == i`. segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s first dimension. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -802,7 +802,7 @@ If the sum is empty for a given segment ID `i`, `output[i] = 0`. segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s first dimension. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `num_segments`. )doc"); @@ -821,7 +821,7 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation of segments. Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first -dimension, selecting a subset of dimension_0, specified by `indices`. +dimension, selecting a subset of dimension 0, specified by `indices`. For example: @@ -850,7 +850,7 @@ indices: A 1-D tensor. Has same rank as `segment_ids`. segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -868,13 +868,13 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation of segments. Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -dimension, selecting a subset of dimension_0, specified by `indices`. +dimension, selecting a subset of dimension 0, specified by `indices`. indices: A 1-D tensor. Has same rank as `segment_ids`. segment_ids: A 1-D tensor. Values should be sorted and can be repeated. -output: Has same shape as data, except for dimension_0 which +output: Has same shape as data, except for dimension 0 which has size `k`, the number of segments. )doc"); @@ -889,13 +889,13 @@ REGISTER_OP("SparseSegmentMeanGrad") .Doc(R"doc( Computes gradients for SparseSegmentMean. -Returns tensor "output" with same shape as grad, except for dimension_0 whose +Returns tensor "output" with same shape as grad, except for dimension 0 whose value is output_dim0. grad: gradient propagated to the SparseSegmentMean op. indices: indices passed to the corresponding SparseSegmentMean op. segment_ids: segment_ids passed to the corresponding SparseSegmentMean op. -output_dim0: dimension_0 of "data" passed to SparseSegmentMean op. +output_dim0: dimension 0 of "data" passed to SparseSegmentMean op. )doc"); REGISTER_OP("All") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 68671e9e957..33875345873 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -1874,7 +1874,7 @@ op { } output_arg { name: "output" - description: "A Tensor with one more dimension than the input bytes. The\nadded dimension will have size equal to the length of the elements\nof bytes divided by the number of bytes to represent out_type." + description: "A Tensor with one more dimension than the input `bytes`. The\nadded dimension will have size equal to the length of the elements\nof `bytes` divided by the number of bytes to represent `out_type`." type_attr: "out_type" } attr { @@ -1898,7 +1898,7 @@ op { default_value { b: true } - description: "Whether the input bytes are in little-endian order.\nIgnored for out_types that are stored in a single byte like uint8." + description: "Whether the input `bytes` are in little-endian order.\nIgnored for `out_type` values that are stored in a single byte like\n`uint8`." } summary: "Reinterpret the bytes of a string as a vector of numbers." } @@ -5666,6 +5666,38 @@ op { summary: "Resize `images` to `size` using nearest neighbor interpolation." description: "Input images can be of different types but output images are always float." } +op { + name: "ResizeNearestNeighborGrad" + input_arg { + name: "grads" + description: "4-D with shape `[batch, height, width, channels]`." + type_attr: "T" + } + input_arg { + name: "size" + description: "= A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The\noriginal input size." + type: DT_INT32 + } + output_arg { + name: "output" + description: "4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients\nwith respect to the input image." + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + type: DT_INT32 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Computes the gradient of nearest neighbor interpolation." +} op { name: "Restore" input_arg { @@ -6125,7 +6157,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -6169,7 +6201,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -6213,7 +6245,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -6257,7 +6289,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -6301,7 +6333,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -7044,7 +7076,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -7058,7 +7090,7 @@ op { } } summary: "Computes the mean along sparse segments of a tensor." - description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentMean`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension_0, specified by `indices`." + description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentMean`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`." } op { name: "SparseSegmentMeanGrad" @@ -7079,7 +7111,7 @@ op { } input_arg { name: "output_dim0" - description: "dimension_0 of \"data\" passed to SparseSegmentMean op." + description: "dimension 0 of \"data\" passed to SparseSegmentMean op." type: DT_INT32 } output_arg { @@ -7097,7 +7129,7 @@ op { } } summary: "Computes gradients for SparseSegmentMean." - description: "Returns tensor \"output\" with same shape as grad, except for dimension_0 whose\nvalue is output_dim0." + description: "Returns tensor \"output\" with same shape as grad, except for dimension 0 whose\nvalue is output_dim0." } op { name: "SparseSegmentSum" @@ -7117,7 +7149,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments." + description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments." type_attr: "T" } attr { @@ -7136,7 +7168,7 @@ op { } } summary: "Computes the sum along sparse segments of a tensor." - description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension_0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n ==> [[ 1 2 3 4]\n [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n ==> [[0 0 0 0]\n [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```" + description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n ==> [[ 1 2 3 4]\n [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n ==> [[0 0 0 0]\n [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```" } op { name: "SparseToDense" @@ -7294,6 +7326,84 @@ op { summary: "Removes dimensions of size 1 from the shape of a tensor." description: "Given a tensor `input`, this operation returns a tensor of the same type with\nall dimensions of size 1 removed. If you don\'t want to remove all size 1\ndimensions, you can remove specific size 1 dimensions by specifying\n`squeeze_dims`.\n\nFor example:\n\n```prettyprint\n# \'t\' is a tensor of shape [1, 2, 1, 3, 1, 1]\nshape(squeeze(t)) ==> [2, 3]\n```\n\nOr, to remove specific size 1 dimensions:\n\n```prettyprint\n# \'t\' is a tensor of shape [1, 2, 1, 3, 1, 1]\nshape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]\n```" } +op { + name: "Stack" + output_arg { + name: "handle" + description: "The handle to the stack." + type: DT_STRING + is_ref: true + } + attr { + name: "elem_type" + type: "type" + description: "The type of the elements on the stack." + } + attr { + name: "stack_name" + type: "string" + default_value { + s: "" + } + description: "Overrides the name used for the temporary stack resource. Default\nvalue is the name of the \'Stack\' op (which is guaranteed unique)." + } + summary: "A stack that produces elements in first-in last-out order." + is_stateful: true +} +op { + name: "StackClose" + input_arg { + name: "handle" + description: "The handle to a stack." + type: DT_STRING + is_ref: true + } + summary: "Delete the stack from its resource container." +} +op { + name: "StackPop" + input_arg { + name: "handle" + description: "The handle to a stack." + type: DT_STRING + is_ref: true + } + output_arg { + name: "elem" + description: "The tensor that is popped from the top of the stack." + type_attr: "elem_type" + } + attr { + name: "elem_type" + type: "type" + description: "The type of the elem that is popped." + } + summary: "Pop the element at the top of the stack." +} +op { + name: "StackPush" + input_arg { + name: "handle" + description: "The handle to a stack." + type: DT_STRING + is_ref: true + } + input_arg { + name: "elem" + description: "The tensor to be pushed onto the stack." + type_attr: "T" + } + output_arg { + name: "output" + description: "The same tensor as the input \'elem\'." + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + summary: "Push an element onto the stack." +} op { name: "StopGradient" input_arg { @@ -7319,7 +7429,7 @@ op { } output_arg { name: "output" - description: "A Tensor of the same shape as the input string_tensor." + description: "A Tensor of the same shape as the input `string_tensor`." type: DT_INT64 } attr { @@ -7340,7 +7450,7 @@ op { } output_arg { name: "output" - description: "A Tensor of the same shape as the input string_tensor." + description: "A Tensor of the same shape as the input `string_tensor`." type_attr: "out_type" } attr { @@ -7942,7 +8052,7 @@ op { } output_arg { name: "output" - description: "Has same shape as data, except for dimension_0 which\nhas size `num_segments`." + description: "Has same shape as data, except for dimension 0 which\nhas size `num_segments`." type_attr: "T" } attr { diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 9d7c57d55bc..d1fd86abbf5 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -26,11 +26,12 @@ REGISTER_OP("DecodeRaw") Reinterpret the bytes of a string as a vector of numbers. bytes: All the elements must have the same length. -little_endian: Whether the input bytes are in little-endian order. - Ignored for out_types that are stored in a single byte like uint8. -output: A Tensor with one more dimension than the input bytes. The +little_endian: Whether the input `bytes` are in little-endian order. + Ignored for `out_type` values that are stored in a single byte like + `uint8`. +output: A Tensor with one more dimension than the input `bytes`. The added dimension will have size equal to the length of the elements - of bytes divided by the number of bytes to represent out_type. + of `bytes` divided by the number of bytes to represent `out_type`. )doc"); REGISTER_OP("ParseExample") @@ -113,7 +114,7 @@ Converts each string in the input Tensor to the specified numeric type. results in a rounded value.) out_type: The numeric type to interpret each string in string_tensor as. -output: A Tensor of the same shape as the input string_tensor. +output: A Tensor of the same shape as the input `string_tensor`. )doc"); } // namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index bfb443fe0a3..d621c5368f5 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -30,7 +30,7 @@ process. Note that the hash function may change from time to time. num_buckets: The number of buckets. -output: A Tensor of the same shape as the input string_tensor. +output: A Tensor of the same shape as the input `string_tensor`. )doc"); } // namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 7cf6c274be5..af25ef3f353 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -38,6 +38,7 @@ def tf_proto_library(name, srcs = [], has_services = False, py_proto_library(name=name + "_py", srcs=srcs + tf_deps(deps, "_proto_srcs"), + srcs_version="PY2AND3", deps=deps, py_libs = ["//google/protobuf:protobuf_python"], testonly=testonly, @@ -46,6 +47,7 @@ def tf_proto_library(name, srcs = [], has_services = False, def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0): py_proto_library(name = name + "_py", srcs = srcs, + srcs_version = "PY2AND3", deps = deps, visibility = visibility, testonly = testonly) diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h index fed39bf810f..35a196d8407 100644 --- a/tensorflow/core/platform/default/thread_annotations.h +++ b/tensorflow/core/platform/default/thread_annotations.h @@ -1,33 +1,18 @@ -// Copyright (c) 2008, Google Inc. -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -// --- -// +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // This header file contains the macro definitions for thread safety // annotations that allow the developers to document the locking policies // of their multi-threaded code. The annotations can also help program diff --git a/tensorflow/core/public/env.h b/tensorflow/core/public/env.h index b3db96eb3ed..e3ecd911d3c 100644 --- a/tensorflow/core/public/env.h +++ b/tensorflow/core/public/env.h @@ -107,7 +107,7 @@ class Env { /// Deletes the specified directory. virtual Status DeleteDir(const string& dirname) = 0; - /// Stores the size of fname in *file_size. + /// Stores the size of `fname` in `*file_size`. virtual Status GetFileSize(const string& fname, uint64* file_size) = 0; /// \brief Renames file src to target. If target already exists, it will be @@ -146,18 +146,18 @@ class RandomAccessFile { RandomAccessFile() {} virtual ~RandomAccessFile(); - /// \brief Reads up to "n" bytes from the file starting at "offset". + /// \brief Reads up to `n` bytes from the file starting at `offset`. /// - /// "scratch[0..n-1]" may be written by this routine. Sets "*result" - /// to the data that was read (including if fewer than "n" bytes were - /// successfully read). May set "*result" to point at data in - /// "scratch[0..n-1]", so "scratch[0..n-1]" must be live when - /// "*result" is used. + /// `scratch[0..n-1]` may be written by this routine. Sets `*result` + /// to the data that was read (including if fewer than `n` bytes were + /// successfully read). May set `*result` to point at data in + /// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when + /// `*result` is used. /// - /// On OK returned status: "n" bytes have been stored in "*result". - /// On non-OK returned status: [0..n] bytes have been stored in "*result". + /// On OK returned status: `n` bytes have been stored in `*result`. + /// On non-OK returned status: `[0..n]` bytes have been stored in `*result`. /// - /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in "*result" + /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` /// because of EOF. /// /// Safe for concurrent use by multiple threads. @@ -263,16 +263,16 @@ struct ThreadOptions { size_t guard_size = 0; // 0: use system default value }; -/// A utility routine: reads contents of named file into *data +/// A utility routine: reads contents of named file into `*data` Status ReadFileToString(Env* env, const string& fname, string* data); -/// A utility routine: write contents of "data" to file named "fname" +/// A utility routine: write contents of `data` to file named `fname` /// (overwriting existing contents, if any). Status WriteStringToFile(Env* env, const string& fname, const StringPiece& data); /// Reads contents of named file and parse as binary encoded proto data -/// and store into *proto. +/// and store into `*proto`. Status ReadBinaryProto(Env* env, const string& fname, ::tensorflow::protobuf::MessageLite* proto); diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index 22a6e5dcef8..86f974ee26b 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -60,6 +60,15 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, return config; } +template <typename T> +__device__ __host__ inline T ldg(const T* address) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350 + return __ldg(address); +#else + return *address; +#endif +} + } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc index b210f42879e..7ea03be5fd7 100644 --- a/tensorflow/core/util/events_writer.cc +++ b/tensorflow/core/util/events_writer.cc @@ -90,7 +90,7 @@ string EventsWriter::FileName() { return filename_; } -void EventsWriter::WriteSerializedEvent(const string& event_str) { +void EventsWriter::WriteSerializedEvent(StringPiece event_str) { if (recordio_writer_.get() == NULL) { if (!Init()) { LOG(ERROR) << "Write failed because file could not be opened."; diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h index 59488ef407c..14f2594a28f 100644 --- a/tensorflow/core/util/events_writer.h +++ b/tensorflow/core/util/events_writer.h @@ -61,8 +61,8 @@ class EventsWriter { // Append "event_str", a serialized Event, to the file. // Note that this function does NOT check that de-serializing event_str - // results in a valid Event proto. - void WriteSerializedEvent(const string& event_str); + // results in a valid Event proto. The tensorflow:: bit makes SWIG happy. + void WriteSerializedEvent(tensorflow::StringPiece event_str); // EventWriter automatically flushes and closes on destruction, but // these two methods are provided for users who want to write to disk sooner diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index fb5bc8da71d..872aeeb6a2f 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -14,6 +14,7 @@ cc_library( copts = [ "-std=c++11", "-mfpu=neon", + "-O2", ], linkopts = ["-llog -landroid -lm -ljnigraphics"], tags = [ diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index 5e24984b494..b4bab119486 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -36,13 +36,9 @@ Then, after editing your WORKSPACE file, you must build the APK. Run this from your workspace root: ```bash -$ bazel build //tensorflow/examples/android:tensorflow_demo -c opt --copt=-mfpu=neon +$ bazel build //tensorflow/examples/android:tensorflow_demo ``` -Note that `-c opt` is currently required; if not set, an assert (for an -otherwise non-problematic issue) in Eigen will halt the application during -execution. This issue will be corrected in an upcoming release. - If adb debugging is enabled on your Android 5.0 or later device, you may then use the following command from your workspace root to install the APK once built: @@ -55,7 +51,7 @@ Alternatively, a streamlined means of building, installing and running in one command is: ```bash -$ bazel mobile-install //tensorflow/examples/android:tensorflow_demo -c opt --start_app --copt=-mfpu=neon +$ bazel mobile-install //tensorflow/examples/android:tensorflow_demo --start_app ``` If camera permission errors are encountered (possible on Android Marshmallow or diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java index 28dadb9ddf5..9c53021636d 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java @@ -231,7 +231,7 @@ public class CameraConnectionFragment extends Fragment { private static Size chooseOptimalSize( final Size[] choices, final int width, final int height, final Size aspectRatio) { // Collect the supported resolutions that are at least as big as the preview Surface - final List<Size> bigEnough = new ArrayList<>(); + final List<Size> bigEnough = new ArrayList<Size>(); for (final Size option : choices) { if (option.getHeight() >= MINIMUM_PREVIEW_SIZE && option.getWidth() >= MINIMUM_PREVIEW_SIZE) { LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight()); diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java index 60b3037c7d1..8ae96843e19 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + package org.tensorflow.demo; import android.graphics.Bitmap; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java index 961b492a8d4..c26446aa610 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + package org.tensorflow.demo; import android.content.Context; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java index 84a7596ecbd..c4e58634df2 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + package org.tensorflow.demo; import android.content.res.AssetManager; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java index 940fbc6771b..0deb7c151eb 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + package org.tensorflow.demo; import android.content.res.AssetManager; @@ -9,7 +24,6 @@ import android.media.Image; import android.media.Image.Plane; import android.media.ImageReader; import android.media.ImageReader.OnImageAvailableListener; - import junit.framework.Assert; import org.tensorflow.demo.env.ImageUtils; @@ -54,6 +68,44 @@ public class TensorflowImageListener implements OnImageAvailableListener { this.scoreView = scoreView; } + private void readPlanesToYuvBuffer(final Plane[] planes, final byte[] yuvBytes) { + int position = 0; + + // Copy the bytes from the Image into a buffer for easier conversion to RGB. + // TODO(andrewharp): Modify native code to accept multiple buffers so that + // only one pass is necessary during conversion to RGB. + final Plane yPlane = planes[0]; + final ByteBuffer yBuffer = yPlane.getBuffer(); + final int yRowStride = yPlane.getRowStride(); + + // Read the y (luminance buffer). + for (int row = 0; row < previewHeight; ++row) { + yBuffer.position(yRowStride * row); + + // Pixel stride is guaranteed to be 1 so we can + // just do a copy operation. + yBuffer.get(yuvBytes, position, previewWidth); + position += previewWidth; + } + + // Interleave the u and v buffers. + final ByteBuffer uBuffer = planes[1].getBuffer(); + final ByteBuffer vBuffer = planes[2].getBuffer(); + final int uvPixelStride = planes[1].getPixelStride(); + final int uvWidth = previewWidth / 2; + final int uvHeight = previewHeight / 2; + Assert.assertEquals( + planes[1].getRowStride(), planes[2].getRowStride()); + for (int y = 0; y < uvHeight; ++y) { + int readPos = planes[1].getRowStride() * y; + for (int x = 0; x < uvWidth; ++x) { + yuvBytes[position++] = vBuffer.get(readPos); + yuvBytes[position++] = uBuffer.get(readPos); + readPos += uvPixelStride; + } + } + } + private void drawResizedBitmap(final Bitmap src, final Bitmap dst) { Assert.assertEquals(dst.getWidth(), dst.getHeight()); final float minDim = Math.min(src.getWidth(), src.getHeight()); @@ -91,31 +143,17 @@ public class TensorflowImageListener implements OnImageAvailableListener { // Initialize the storage bitmaps once when the resolution is known. if (previewWidth != image.getWidth() || previewHeight != image.getHeight()) { - LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); previewWidth = image.getWidth(); previewHeight = image.getHeight(); + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); rgbBytes = new int[previewWidth * previewHeight]; yuvBytes = new byte[ImageUtils.getYUVByteSize(previewWidth, previewHeight)]; rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); } - final Plane[] planes = image.getPlanes(); - int position = 0; - - // Copy the bytes from the Image into a buffer for easier conversion to RGB. - // TODO(andrewharp): It may not be correct to do it this way. - final int[] planeOrder = {0, 2}; - for (int i = 0; i < planeOrder.length; ++i) { - final Plane plane = planes[planeOrder[i]]; - final ByteBuffer buffer = plane.getBuffer(); - - buffer.rewind(); - final int readAmount = buffer.remaining(); - - buffer.get(yuvBytes, position, readAmount); - position += readAmount; - } + readPlanesToYuvBuffer(image.getPlanes(), yuvBytes); image.close(); diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java index 78f818f7345..f9e564908b0 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + package org.tensorflow.demo.env; import android.graphics.Bitmap; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java index 697c2311766..fe6a83a0503 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + package org.tensorflow.demo.env; import android.util.Log; diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc index 00c1adefb77..c78ee33e06d 100644 --- a/tensorflow/examples/label_image/main.cc +++ b/tensorflow/examples/label_image/main.cc @@ -20,8 +20,8 @@ limitations under the License. // It's designed to have as few dependencies and be as clear as possible, so // it's more verbose than it could be in production code. In particular, using // auto for the types of a lot of the returned values from TensorFlow calls can -// remove a lot of boilerplate, but I find them explicit types useful in sample -// code to make it simple to look up the classes involved. +// remove a lot of boilerplate, but I find the explicit types useful in sample +// code to make it simple to look up the classes involved. // // To use it, compile and then run in a working directory with the // learning/brain/tutorials/label_image/data/ folder below it, and you should diff --git a/tensorflow/g3doc/api_docs/cc/ClassEnv.md b/tensorflow/g3doc/api_docs/cc/ClassEnv.md index 62bc69ada8f..38bb94ac635 100644 --- a/tensorflow/g3doc/api_docs/cc/ClassEnv.md +++ b/tensorflow/g3doc/api_docs/cc/ClassEnv.md @@ -27,7 +27,7 @@ All Env implementations are safe for concurrent access from multiple threads wit * [`virtual Status tensorflow::Env::DeleteDir(const string &dirname)=0`](#virtual_Status_tensorflow_Env_DeleteDir) * Deletes the specified directory. * [`virtual Status tensorflow::Env::GetFileSize(const string &fname, uint64 *file_size)=0`](#virtual_Status_tensorflow_Env_GetFileSize) - * Stores the size of fname in *file_size. + * Stores the size of `fname` in `*file_size`. * [`virtual Status tensorflow::Env::RenameFile(const string &src, const string &target)=0`](#virtual_Status_tensorflow_Env_RenameFile) * Renames file src to target. If target already exists, it will be replaced. * [`virtual uint64 tensorflow::Env::NowMicros()=0`](#virtual_uint64_tensorflow_Env_NowMicros) @@ -109,7 +109,7 @@ Deletes the specified directory. #### `virtual Status tensorflow::Env::GetFileSize(const string &fname, uint64 *file_size)=0` {#virtual_Status_tensorflow_Env_GetFileSize} -Stores the size of fname in *file_size. +Stores the size of `fname` in `*file_size`. diff --git a/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md b/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md index 54ffc6f5592..9ed2a970165 100644 --- a/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md +++ b/tensorflow/g3doc/api_docs/cc/ClassEnvWrapper.md @@ -28,7 +28,7 @@ May be useful to clients who wish to override just part of the functionality of * [`Status tensorflow::EnvWrapper::DeleteDir(const string &d) override`](#Status_tensorflow_EnvWrapper_DeleteDir) * Deletes the specified directory. * [`Status tensorflow::EnvWrapper::GetFileSize(const string &f, uint64 *s) override`](#Status_tensorflow_EnvWrapper_GetFileSize) - * Stores the size of fname in *file_size. + * Stores the size of `fname` in `*file_size`. * [`Status tensorflow::EnvWrapper::RenameFile(const string &s, const string &t) override`](#Status_tensorflow_EnvWrapper_RenameFile) * Renames file src to target. If target already exists, it will be replaced. * [`uint64 tensorflow::EnvWrapper::NowMicros() override`](#uint64_tensorflow_EnvWrapper_NowMicros) @@ -114,7 +114,7 @@ Deletes the specified directory. #### `Status tensorflow::EnvWrapper::GetFileSize(const string &f, uint64 *s) override` {#Status_tensorflow_EnvWrapper_GetFileSize} -Stores the size of fname in *file_size. +Stores the size of `fname` in `*file_size`. diff --git a/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md b/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md index 6afcf1ee9fd..7655556fad6 100644 --- a/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md +++ b/tensorflow/g3doc/api_docs/cc/ClassRandomAccessFile.md @@ -9,7 +9,7 @@ A file abstraction for randomly reading the contents of a file. * [`tensorflow::RandomAccessFile::RandomAccessFile()`](#tensorflow_RandomAccessFile_RandomAccessFile) * [`virtual tensorflow::RandomAccessFile::~RandomAccessFile()`](#virtual_tensorflow_RandomAccessFile_RandomAccessFile) * [`virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0`](#virtual_Status_tensorflow_RandomAccessFile_Read) - * Reads up to "n" bytes from the file starting at "offset". + * Reads up to `n` bytes from the file starting at `offset`. ##Member Details @@ -27,12 +27,12 @@ A file abstraction for randomly reading the contents of a file. #### `virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0` {#virtual_Status_tensorflow_RandomAccessFile_Read} -Reads up to "n" bytes from the file starting at "offset". +Reads up to `n` bytes from the file starting at `offset`. -"scratch[0..n-1]" may be written by this routine. Sets "*result" to the data that was read (including if fewer than "n" bytes were successfully read). May set "*result" to point at data in "scratch[0..n-1]", so "scratch[0..n-1]" must be live when "*result" is used. +`scratch[0..n-1]` may be written by this routine. Sets `*result` to the data that was read (including if fewer than `n` bytes were successfully read). May set `*result` to point at data in `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when `*result` is used. -On OK returned status: "n" bytes have been stored in "*result". On non-OK returned status: [0..n] bytes have been stored in "*result". +On OK returned status: `n` bytes have been stored in `*result`. On non-OK returned status: `[0..n]` bytes have been stored in `*result`. -Returns `OUT_OF_RANGE` if fewer than n bytes were stored in "*result" because of EOF. +Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` because of EOF. Safe for concurrent use by multiple threads. diff --git a/tensorflow/g3doc/api_docs/cc/index.md b/tensorflow/g3doc/api_docs/cc/index.md index ff57cc2445c..2bb24375cb9 100644 --- a/tensorflow/g3doc/api_docs/cc/index.md +++ b/tensorflow/g3doc/api_docs/cc/index.md @@ -4,7 +4,7 @@ TensorFlow's public C++ API includes only the API for executing graphs, as of version 0.5. To control the execution of a graph from C++: 1. Build the computation graph using the [Python API](../python/). -1. Use [tf.train.write_graph()](../python/train.md#write_graph) to +1. Use [`tf.train.write_graph()`](../python/train.md#write_graph) to write the graph to a file. 1. Load the graph using the C++ Session API. For example: diff --git a/tensorflow/g3doc/api_docs/images/DynamicPartition.png b/tensorflow/g3doc/api_docs/images/DynamicPartition.png deleted file mode 100644 index 56bee8a5a85..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/DynamicPartition.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/DynamicStitch.png b/tensorflow/g3doc/api_docs/images/DynamicStitch.png deleted file mode 100644 index 09d200c1caf..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/DynamicStitch.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/Gather.png b/tensorflow/g3doc/api_docs/images/Gather.png deleted file mode 100644 index 8fda752e4e2..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/Gather.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/ScatterAdd.png b/tensorflow/g3doc/api_docs/images/ScatterAdd.png deleted file mode 100644 index f8f5edf0bf0..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/ScatterAdd.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/ScatterSub.png b/tensorflow/g3doc/api_docs/images/ScatterSub.png deleted file mode 100644 index 8e3672ec8cb..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/ScatterSub.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/ScatterUpdate.png b/tensorflow/g3doc/api_docs/images/ScatterUpdate.png deleted file mode 100644 index 8c0d2c3ea5c..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/ScatterUpdate.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/SegmentMax.png b/tensorflow/g3doc/api_docs/images/SegmentMax.png deleted file mode 100644 index 81e724cd019..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/SegmentMax.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/SegmentMean.png b/tensorflow/g3doc/api_docs/images/SegmentMean.png deleted file mode 100644 index e619bd7e47d..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/SegmentMean.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/SegmentMin.png b/tensorflow/g3doc/api_docs/images/SegmentMin.png deleted file mode 100644 index 52e8ca671c4..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/SegmentMin.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/SegmentProd.png b/tensorflow/g3doc/api_docs/images/SegmentProd.png deleted file mode 100644 index b0f49f29786..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/SegmentProd.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/SegmentSum.png b/tensorflow/g3doc/api_docs/images/SegmentSum.png deleted file mode 100644 index 3b9ce88086d..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/SegmentSum.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/images/UnsortedSegmentSum.png b/tensorflow/g3doc/api_docs/images/UnsortedSegmentSum.png deleted file mode 100644 index a747765c993..00000000000 Binary files a/tensorflow/g3doc/api_docs/images/UnsortedSegmentSum.png and /dev/null differ diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md index fe92f98ec6e..14d1e396233 100644 --- a/tensorflow/g3doc/api_docs/python/array_ops.md +++ b/tensorflow/g3doc/api_docs/python/array_ops.md @@ -32,7 +32,7 @@ results in a rounded value.) ##### Returns: A `Tensor` of type `out_type`. - A Tensor of the same shape as the input string_tensor. + A Tensor of the same shape as the input `string_tensor`. - - - @@ -817,9 +817,9 @@ tf.transpose(x) ==> [[1 4] [3 6]] # Equivalently -tf.transpose(x perm=[0, 1]) ==> [[1 4] - [2 5] - [3 6]] +tf.transpose(x, perm=[1, 0]) ==> [[1 4] + [2 5] + [3 6]] # 'perm' is more useful for n-dimensional tensors, for n > 2 # 'x' is [[[1 2 3] diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md index b35da9a1ae9..3b8f1ceac7e 100644 --- a/tensorflow/g3doc/api_docs/python/client.md +++ b/tensorflow/g3doc/api_docs/python/client.md @@ -178,6 +178,7 @@ Calling this method frees all resources associated with the session. The graph that was launched in this session. + - - - #### `tf.Session.as_default()` {#Session.as_default} @@ -354,6 +355,7 @@ discover information about the op. The `Operation` that failed, or None. + - - - #### `tf.OpError.node_def` {#OpError.node_def} @@ -361,6 +363,7 @@ discover information about the op. The `NodeDef` proto representing the op that failed. + #### Other Methods - - - @@ -383,6 +386,7 @@ Creates a new `OpError` indicating that a particular op failed. The integer error code that describes the error. + - - - #### `tf.OpError.message` {#OpError.message} @@ -390,6 +394,7 @@ The integer error code that describes the error. The error message that describes the error. + - - - ### `class tf.errors.CancelledError` {#CancelledError} diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md index 34f97bae5fd..dc69b2a6175 100644 --- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md +++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md @@ -504,7 +504,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. ### `tf.add_check_numerics_ops()` {#add_check_numerics_ops} -Connect a check_numerics to every floating point tensor. +Connect a `check_numerics` to every floating point tensor. `check_numerics` operations themselves are added for each `float` or `double` tensor in the graph. For all ops in the graph, the `check_numerics` op for diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index 31aa0eee12d..eea786647b5 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -119,7 +119,7 @@ This method is thread-safe. ##### Raises: -* <b>`ValueError`</b>: If the graph_def would be too large. +* <b>`ValueError`</b>: If the `graph_def` would be too large. - - - @@ -141,6 +141,7 @@ when using a [`QueueRunner`](../../api_docs/python/train.md#QueueRunner). True if this graph has been finalized. + - - - #### `tf.Graph.control_dependencies(control_inputs)` {#Graph.control_dependencies} @@ -511,6 +512,7 @@ Returns the default device. + - - - #### `tf.Graph.unique_name(name)` {#Graph.unique_name} @@ -545,6 +547,7 @@ TensorBoard. Returns a version number that increases as ops are added to the graph. + - - - #### `tf.Graph.create_op(op_type, inputs, dtypes, input_types=None, name=None, attrs=None, op_def=None, compute_shapes=True)` {#Graph.create_op} @@ -658,18 +661,21 @@ be executed by passing it to The full name of this operation. + - - - #### `tf.Operation.type` {#Operation.type} The type of the op (e.g. `"MatMul"`). + - - - #### `tf.Operation.inputs` {#Operation.inputs} The list of `Tensor` objects representing the data inputs of this op. + - - - #### `tf.Operation.control_inputs` {#Operation.control_inputs} @@ -686,12 +692,14 @@ in the correct order. A list of `Operation` objects. + - - - #### `tf.Operation.outputs` {#Operation.outputs} The list of `Tensor` objects representing the outputs of this op. + - - - #### `tf.Operation.device` {#Operation.device} @@ -703,6 +711,7 @@ The name of the device to which this op has been assigned, if any. The string name of the device to which this op has been assigned, or None if it has not been assigned to a device. + - - - #### `tf.Operation.graph` {#Operation.graph} @@ -710,6 +719,7 @@ The name of the device to which this op has been assigned, if any. The `Graph` that contains this operation. + - - - #### `tf.Operation.run(feed_dict=None, session=None)` {#Operation.run} @@ -762,6 +772,7 @@ Returns the value of the attr of this op with the given `name`. Returns the call stack from when this operation was constructed. + #### Other Methods - - - @@ -803,11 +814,11 @@ regular expression: * <b>`TypeError`</b>: if control inputs are not Operations or Tensors, - or if node_def is not a `NodeDef`, - or if g is not a `Graph`, - or if inputs are not tensors, - or if inputs and input_types are incompatible. -* <b>`ValueError`</b>: if the node_def name is not valid. + or if `node_def` is not a `NodeDef`, + or if `g` is not a `Graph`, + or if `inputs` are not tensors, + or if `inputs` and `input_types` are incompatible. +* <b>`ValueError`</b>: if the `node_def` name is not valid. - - - @@ -822,6 +833,7 @@ Returns a serialized `NodeDef` representation of this operation. [`NodeDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto) protocol buffer. + - - - #### `tf.Operation.op_def` {#Operation.op_def} @@ -834,6 +846,7 @@ Returns the `OpDef` proto that represents the type of this op. [`OpDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def.proto) protocol buffer. + - - - #### `tf.Operation.values()` {#Operation.values} @@ -889,30 +902,35 @@ result = sess.run(e) The `DType` of elements in this tensor. + - - - #### `tf.Tensor.name` {#Tensor.name} The string name of this tensor. + - - - #### `tf.Tensor.value_index` {#Tensor.value_index} The index of this tensor in the outputs of its `Operation`. + - - - #### `tf.Tensor.graph` {#Tensor.graph} The `Graph` that contains this tensor. + - - - #### `tf.Tensor.op` {#Tensor.op} The `Operation` that produces this tensor as an output. + - - - #### `tf.Tensor.consumers()` {#Tensor.consumers} @@ -1070,6 +1088,7 @@ The name of the device on which this tensor will be produced, or None. + ## Tensor types - - - @@ -1136,30 +1155,35 @@ DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True Returns the string name for this `DType`. + - - - #### `tf.DType.base_dtype` {#DType.base_dtype} Returns a non-reference `DType` based on this `DType`. + - - - #### `tf.DType.is_ref_dtype` {#DType.is_ref_dtype} Returns `True` if this `DType` represents a reference type. + - - - #### `tf.DType.as_ref` {#DType.as_ref} Returns a reference `DType` based on this `DType`. + - - - #### `tf.DType.is_integer` {#DType.is_integer} Returns whether this is a (non-quantized) integer type. + - - - #### `tf.DType.is_quantized` {#DType.is_quantized} @@ -1167,12 +1191,14 @@ Returns whether this is a (non-quantized) integer type. Returns whether this is a quantized data type. + - - - #### `tf.DType.as_numpy_dtype` {#DType.as_numpy_dtype} Returns a `numpy.dtype` based on this `DType`. + - - - #### `tf.DType.as_datatype_enum` {#DType.as_datatype_enum} @@ -1180,6 +1206,7 @@ Returns a `numpy.dtype` based on this `DType`. Returns a `types_pb2.DataType` enum value based on this `DType`. + #### Other Methods - - - @@ -1188,8 +1215,8 @@ Returns a `types_pb2.DataType` enum value based on this `DType`. Creates a new `DataType`. NOTE(mrry): In normal circumstances, you should not need to -construct a DataType object directly. Instead, use the -types.as_dtype() function. +construct a `DataType` object directly. Instead, use the +`tf.as_dtype()` function. ##### Args: @@ -1208,6 +1235,7 @@ types.as_dtype() function. Returns whether this is a (real) floating point type. + - - - #### `tf.DType.max` {#DType.max} @@ -1219,6 +1247,7 @@ Returns the maximum representable value in this data type. * <b>`TypeError`</b>: if this is a non-numeric, unordered, or quantized type. + - - - #### `tf.DType.min` {#DType.min} @@ -1231,6 +1260,7 @@ Returns the minimum representable value in this data type. * <b>`TypeError`</b>: if this is a non-numeric, unordered, or quantized type. + - - - ### `tf.as_dtype(type_value)` {#as_dtype} @@ -1425,13 +1455,13 @@ protocol buffer, and extract individual objects in the `GraphDef` as ##### Returns: A list of `Operation` and/or `Tensor` objects from the imported graph, - corresponding to the names in `return_elements'. + corresponding to the names in `return_elements`. ##### Raises: * <b>`TypeError`</b>: If `graph_def` is not a `GraphDef` proto, - `input_map' is not a dictionary mapping strings to `Tensor` objects, + `input_map` is not a dictionary mapping strings to `Tensor` objects, or `return_elements` is not a list of strings. * <b>`ValueError`</b>: If `input_map`, or `return_elements` contains names that do not appear in `graph_def`, or `graph_def` is not well-formed (e.g. @@ -1613,7 +1643,7 @@ that defines the operation. #### `tf.RegisterShape.__init__(op_type)` {#RegisterShape.__init__} -Saves the "op_type" as the Operation type. +Saves the `op_type` as the `Operation` type. @@ -1694,12 +1724,14 @@ information for use with slicing. Returns the rank of this shape, or None if it is unspecified. + - - - #### `tf.TensorShape.dims` {#TensorShape.dims} Returns a list of Dimensions, or None if the shape is unspecified. + - - - #### `tf.TensorShape.as_list()` {#TensorShape.as_list} @@ -1916,7 +1948,7 @@ Creates a new TensorShape with the given dimensions. #### `tf.TensorShape.as_dimension_list()` {#TensorShape.as_dimension_list} -DEPRECATED: use as_list(). +DEPRECATED: use `as_list()`. - - - @@ -2014,6 +2046,7 @@ Dimensions are combined as follows: The value of this dimension, or None if it is unknown. + - - - ### `tf.op_scope(values, name, default_name)` {#op_scope} diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md index a275c622a2b..5ca185edf4a 100644 --- a/tensorflow/g3doc/api_docs/python/image.md +++ b/tensorflow/g3doc/api_docs/python/image.md @@ -425,7 +425,7 @@ Crops an image to a specified bounding box. This op cuts a rectangular part out of `image`. The top-left corner of the returned image is at `offset_height, offset_width` in `image`, and its lower-right corner is at -`offset_height + target_height, offset_width + target_width'. +`offset_height + target_height, offset_width + target_width`. ##### Args: @@ -724,7 +724,7 @@ have modifications in the range `[-max_delta,max_delta]`. ##### Raises: -* <b>`ValueError`</b>: if max_delta is negative. +* <b>`ValueError`</b>: if `max_delta` is negative. @@ -831,3 +831,27 @@ Note that this implementation is limited: * <b>`ValueError`</b>: if the shape of 'image' is incompatible with this function. + +## Other Functions and Classes +- - - + +### `tf.image.resize_nearest_neighbor_grad(grads, size, name=None)` {#resize_nearest_neighbor_grad} + +Computes the gradient of nearest neighbor interpolation. + +##### Args: + + +* <b>`grads`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`. + 4-D with shape `[batch, height, width, channels]`. +* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The + original input size. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A `Tensor`. Has the same type as `grads`. + 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients + with respect to the input image. + + diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 14a4174315c..1211233fae9 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -215,6 +215,7 @@ * [`resize_image_with_crop_or_pad`](../../api_docs/python/image.md#resize_image_with_crop_or_pad) * [`resize_images`](../../api_docs/python/image.md#resize_images) * [`resize_nearest_neighbor`](../../api_docs/python/image.md#resize_nearest_neighbor) + * [`resize_nearest_neighbor_grad`](../../api_docs/python/image.md#resize_nearest_neighbor_grad) * [`transpose_image`](../../api_docs/python/image.md#transpose_image) * **[Sparse Tensors](../../api_docs/python/sparse_ops.md)**: diff --git a/tensorflow/g3doc/api_docs/python/io_ops.md b/tensorflow/g3doc/api_docs/python/io_ops.md index dbb358f45f6..0d4d52eea55 100644 --- a/tensorflow/g3doc/api_docs/python/io_ops.md +++ b/tensorflow/g3doc/api_docs/python/io_ops.md @@ -153,6 +153,7 @@ finished with the previous file). Op that implements the reader. + - - - #### `tf.ReaderBase.reset(name=None)` {#ReaderBase.reset} @@ -216,6 +217,7 @@ Unimplemented error. Whether the Reader implementation can serialize its state. + - - - ### `class tf.TextLineReader` {#TextLineReader} @@ -304,6 +306,7 @@ finished with the previous file). Op that implements the reader. + - - - #### `tf.TextLineReader.reset(name=None)` {#TextLineReader.reset} @@ -367,6 +370,7 @@ Unimplemented error. Whether the Reader implementation can serialize its state. + - - - ### `class tf.WholeFileReader` {#WholeFileReader} @@ -455,6 +459,7 @@ finished with the previous file). Op that implements the reader. + - - - #### `tf.WholeFileReader.reset(name=None)` {#WholeFileReader.reset} @@ -518,6 +523,7 @@ Unimplemented error. Whether the Reader implementation can serialize its state. + - - - ### `class tf.IdentityReader` {#IdentityReader} @@ -606,6 +612,7 @@ finished with the previous file). Op that implements the reader. + - - - #### `tf.IdentityReader.reset(name=None)` {#IdentityReader.reset} @@ -669,6 +676,7 @@ Unimplemented error. Whether the Reader implementation can serialize its state. + - - - ### `class tf.TFRecordReader` {#TFRecordReader} @@ -754,6 +762,7 @@ finished with the previous file). Op that implements the reader. + - - - #### `tf.TFRecordReader.reset(name=None)` {#TFRecordReader.reset} @@ -817,6 +826,7 @@ Unimplemented error. Whether the Reader implementation can serialize its state. + - - - ### `class tf.FixedLengthRecordReader` {#FixedLengthRecordReader} @@ -905,6 +915,7 @@ finished with the previous file). Op that implements the reader. + - - - #### `tf.FixedLengthRecordReader.reset(name=None)` {#FixedLengthRecordReader.reset} @@ -969,6 +980,7 @@ Whether the Reader implementation can serialize its state. + ## Converting TensorFlow provides several operations that you can use to convert various data @@ -1016,16 +1028,17 @@ Reinterpret the bytes of a string as a vector of numbers. All the elements must have the same length. * <b>`out_type`</b>: A `tf.DType` from: `tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.int64`. * <b>`little_endian`</b>: An optional `bool`. Defaults to `True`. - Whether the input bytes are in little-endian order. - Ignored for out_types that are stored in a single byte like uint8. + Whether the input `bytes` are in little-endian order. + Ignored for `out_type` values that are stored in a single byte like + `uint8`. * <b>`name`</b>: A name for the operation (optional). ##### Returns: A `Tensor` of type `out_type`. - A Tensor with one more dimension than the input bytes. The + A Tensor with one more dimension than the input `bytes`. The added dimension will have size equal to the length of the elements - of bytes divided by the number of bytes to represent out_type. + of `bytes` divided by the number of bytes to represent `out_type`. @@ -1084,12 +1097,12 @@ serialized `Example`s are provided: ``` serialized = [ - features: - { feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] }, - features: - { feature: [] }, - features: - { feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] } + features + { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } }, + features + { feature []}, + features + { feature { key: "ft" value { float_list { value: [3.0] } } } ] ``` @@ -1105,14 +1118,14 @@ Given two `Example` input protos in `serialized`: ``` [ - features: { - feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } } - feature: { key: "gps" value: { float_list: { value: [] } } } + features { + feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } } + feature { key: "gps" value { float_list { value: [] } } } }, - features: { - feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } } - feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } } - feature: { key: "gps" value: { } } + features { + feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } } + feature { key: "dank" value { int64_list { value: [ 42 ] } } } + feature { key: "gps" value { } } } ] ``` @@ -1148,13 +1161,13 @@ For dense results in two serialized `Example`s: ``` [ - features: { - feature: { key: "age" value: { int64_list: { value: [ 0 ] } } } - feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } } + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } }, - features: { - feature: { key: "age" value: { int64_list: { value: [] } } } - feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } } + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } } ] ``` @@ -1202,6 +1215,8 @@ And the expected output is: The keys of the dict must match the dense_keys of the feature. * <b>`dense_shapes`</b>: A list of tuples with the same length as `dense_keys`. The shape of the data for each dense feature referenced by `dense_keys`. + Required for any input tensors identified by dense_keys whose shapes are + anything other than [] or [1]. * <b>`name`</b>: A name for this operation (optional). ##### Returns: @@ -1229,7 +1244,7 @@ same as the shape given in `dense_shape`. For `SparseTensor`s, the first (batch) column of the indices matrix is removed (the indices matrix is a column vector), the values vector is unchanged, and -the first (batch_size) entry of the shape vector is removed (it is now a +the first (`batch_size`) entry of the shape vector is removed (it is now a single element vector). See also `parse_example`. @@ -1238,15 +1253,15 @@ See also `parse_example`. * <b>`serialized`</b>: A scalar string Tensor, a single serialized Example. - See parse_example documentation for more details. + See `parse_example` documentation for more details. * <b>`names`</b>: (Optional) A scalar string Tensor, the associated name. - See parse_example documentation for more details. -* <b>`sparse_keys`</b>: See parse_example documentation for more details. -* <b>`sparse_types`</b>: See parse_example documentation for more details. -* <b>`dense_keys`</b>: See parse_example documentation for more details. -* <b>`dense_types`</b>: See parse_example documentation for more details. -* <b>`dense_defaults`</b>: See parse_example documentation for more details. -* <b>`dense_shapes`</b>: See parse_example documentation for more details. + See `parse_example` documentation for more details. +* <b>`sparse_keys`</b>: See `parse_example` documentation for more details. +* <b>`sparse_types`</b>: See `parse_example` documentation for more details. +* <b>`dense_keys`</b>: See `parse_example` documentation for more details. +* <b>`dense_types`</b>: See `parse_example` documentation for more details. +* <b>`dense_defaults`</b>: See `parse_example` documentation for more details. +* <b>`dense_shapes`</b>: See `parse_example` documentation for more details. * <b>`name`</b>: A name for this operation (optional). ##### Returns: @@ -1450,12 +1465,38 @@ Constructs a queue object from a queue reference. The list of dtypes for each component of a queue element. + +- - - + +#### `tf.QueueBase.from_list(index, queues)` {#QueueBase.from_list} + +Create a queue using the queue reference from `queues[index]`. + +##### Args: + + +* <b>`index`</b>: An integer scalar tensor that determines the input that gets + selected. +* <b>`queues`</b>: A list of `QueueBase` objects. + +##### Returns: + + A `QueueBase` object. + +##### Raises: + + +* <b>`TypeError`</b>: When `queues` is not a list of `QueueBase` objects, + or when the data types of `queues` are not all the same. + + - - - #### `tf.QueueBase.name` {#QueueBase.name} The name of the underlying queue. + - - - #### `tf.QueueBase.queue_ref` {#QueueBase.queue_ref} @@ -1463,6 +1504,7 @@ The name of the underlying queue. The underlying queue reference. + - - - ### `class tf.FIFOQueue` {#FIFOQueue} @@ -1635,7 +1677,7 @@ Save the list of files matching pattern, so it is only computed once. ### `tf.train.limit_epochs(tensor, num_epochs=None, name=None)` {#limit_epochs} -Returns tensor num_epochs times and then raises an OutOfRange error. +Returns tensor `num_epochs` times and then raises an `OutOfRange` error. ##### Args: @@ -1647,7 +1689,7 @@ Returns tensor num_epochs times and then raises an OutOfRange error. ##### Returns: - tensor or OutOfRange. + tensor or `OutOfRange`. - - - @@ -1773,6 +1815,12 @@ first dimension. If an input tensor has shape `[*, x, y, z]`, the output will have shape `[batch_size, x, y, z]`. The `capacity` argument controls the how long the prefetching is allowed to grow the queues. +The returned operation is a dequeue operation and will throw +`tf.errors.OutOfRangeError` if the input queue is exhausted. If this +operation is feeding another input queue, its queue runner will catch +this exception, however, if this operation is used in your main thread +you are responsible for catching this yourself. + *N.B.:* You must ensure that either (i) the `shapes` argument is passed, or (ii) all of the tensors in `tensor_list` must have fully-defined shapes. `ValueError` will be raised if neither of @@ -1831,6 +1879,12 @@ same size in the first dimension. The slices of any input tensor The `capacity` argument controls the how long the prefetching is allowed to grow the queues. +The returned operation is a dequeue operation and will throw +`tf.errors.OutOfRangeError` if the input queue is exhausted. If this +operation is feeding another input queue, its queue runner will catch +this exception, however, if this operation is used in your main thread +you are responsible for catching this yourself. + *N.B.:* You must ensure that either (i) the `shapes` argument is passed, or (ii) all of the tensors in `tensor_list_list` must have fully-defined shapes. `ValueError` will be raised if neither of @@ -1886,6 +1940,12 @@ output will have shape `[batch_size, x, y, z]`. The `capacity` argument controls the how long the prefetching is allowed to grow the queues. +The returned operation is a dequeue operation and will throw +`tf.errors.OutOfRangeError` if the input queue is exhausted. If this +operation is feeding another input queue, its queue runner will catch +this exception, however, if this operation is used in your main thread +you are responsible for catching this yourself. + For example: ```python @@ -1961,10 +2021,11 @@ y, z]`, the output will have shape `[batch_size, x, y, z]`. The `capacity` argument controls the how long the prefetching is allowed to grow the queues. -*N.B.:* You must ensure that either (i) the `shapes` argument is -passed, or (ii) all of the tensors in `tensor_list_list` must have -fully-defined shapes. `ValueError` will be raised if neither of -these conditions holds. +The returned operation is a dequeue operation and will throw +`tf.errors.OutOfRangeError` if the input queue is exhausted. If this +operation is feeding another input queue, its queue runner will catch +this exception, however, if this operation is used in your main thread +you are responsible for catching this yourself. ##### Args: diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index f0fe5b200ff..43261de10bf 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -590,9 +590,9 @@ tf.transpose(x) ==> [[1 4] [3 6]] # Equivalently -tf.transpose(x perm=[0, 1]) ==> [[1 4] - [2 5] - [3 6]] +tf.transpose(x, perm=[1, 0]) ==> [[1 4] + [2 5] + [3 6]] # 'perm' is more useful for n-dimensional tensors, for n > 2 # 'x' is [[[1 2 3] @@ -1355,7 +1355,7 @@ that `segment_ids[j] == i`. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. @@ -1389,7 +1389,7 @@ that `segment_ids[j] == i`. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. @@ -1423,7 +1423,7 @@ that `segment_ids[j] == i`. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. @@ -1456,7 +1456,7 @@ that `segment_ids[j] == i`. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. @@ -1491,7 +1491,7 @@ values summed. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. @@ -1533,7 +1533,7 @@ If the sum is empty for a given segment ID `i`, `output[i] = 0`. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `num_segments`. @@ -1549,7 +1549,7 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation of segments. Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first -dimension, selecting a subset of dimension_0, specified by `indices`. +dimension, selecting a subset of dimension 0, specified by `indices`. For example: @@ -1587,7 +1587,7 @@ tf.segment_sum(c, tf.constant([0, 0, 1])) ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. @@ -1602,7 +1602,7 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation of segments. Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first -dimension, selecting a subset of dimension_0, specified by `indices`. +dimension, selecting a subset of dimension 0, specified by `indices`. ##### Args: @@ -1617,7 +1617,7 @@ dimension, selecting a subset of dimension_0, specified by `indices`. ##### Returns: A `Tensor`. Has the same type as `data`. - Has same shape as data, except for dimension_0 which + Has same shape as data, except for dimension 0 which has size `k`, the number of segments. diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md index 2fd07fa692e..068e5f2ec47 100644 --- a/tensorflow/g3doc/api_docs/python/nn.md +++ b/tensorflow/g3doc/api_docs/python/nn.md @@ -720,7 +720,7 @@ tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`. * <b>`params`</b>: A list of tensors with the same shape and type. -* <b>`ids`</b>: A `Tensor` with type `int32` containing the ids to be looked +* <b>`ids`</b>: A `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. * <b>`name`</b>: A name for the operation (optional). @@ -850,17 +850,19 @@ with an otherwise unused class. ##### Args: -* <b>`weights`</b>: A `Tensor` of shape [num_classes, dim]. The class embeddings. -* <b>`biases`</b>: A `Tensor` of shape [num_classes]. The class biases. -* <b>`inputs`</b>: A `Tensor` of shape [batch_size, dim]. The forward +* <b>`weights`</b>: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + [num_classes, dim]. The (possibly-sharded) class embeddings. +* <b>`biases`</b>: A `Tensor` of shape `[num_classes]`. The class biases. +* <b>`inputs`</b>: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. * <b>`labels`</b>: A `Tensor` of type `int64` and shape `[batch_size, - num_true]`. The target classes. + num_true]`. The target classes. * <b>`num_sampled`</b>: An `int`. The number of classes to randomly sample per batch. * <b>`num_classes`</b>: An `int`. The number of possible classes. * <b>`num_true`</b>: An `int`. The number of target classes per training example. -* <b>`sampled_values`</b>: a tuple of `(sampled_candidates, true_expected_count, - sampled_expected_count)` returned by a `*_candidate_sampler` function. +* <b>`sampled_values`</b>: a tuple of (`sampled_candidates`, `true_expected_count`, + `sampled_expected_count`) returned by a `*_candidate_sampler` function. (if None, we default to `log_uniform_candidate_sampler`) * <b>`remove_accidental_hits`</b>: A `bool`. Whether to remove "accidental hits" where a sampled class equals one of the target classes. If set to @@ -873,7 +875,7 @@ with an otherwise unused class. ##### Returns: - A batch_size 1-D tensor of per-example NCE losses. + A `batch_size` 1-D tensor of per-example NCE losses. - - - @@ -899,18 +901,20 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math. ##### Args: -* <b>`weights`</b>: A `Tensor` of shape [num_classes, dim]. The class embeddings. -* <b>`biases`</b>: A `Tensor` of shape [num_classes]. The class biases. -* <b>`inputs`</b>: A `Tensor` of shape [batch_size, dim]. The forward +* <b>`weights`</b>: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + [num_classes, dim]. The (possibly-sharded) class embeddings. +* <b>`biases`</b>: A `Tensor` of shape `[num_classes]`. The class biases. +* <b>`inputs`</b>: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. * <b>`labels`</b>: A `Tensor` of type `int64` and shape `[batch_size, - num_true]`. The target classes. Note that this format differs from - the `labels` argument of `nn.softmax_cross_entropy_with_logits`. + num_true]`. The target classes. Note that this format differs from + the `labels` argument of `nn.softmax_cross_entropy_with_logits`. * <b>`num_sampled`</b>: An `int`. The number of classes to randomly sample per batch. * <b>`num_classes`</b>: An `int`. The number of possible classes. * <b>`num_true`</b>: An `int`. The number of target classes per training example. -* <b>`sampled_values`</b>: a tuple of `(sampled_candidates, true_expected_count, - sampled_expected_count)` returned by a `*_candidate_sampler` function. +* <b>`sampled_values`</b>: a tuple of (`sampled_candidates`, `true_expected_count`, + `sampled_expected_count`) returned by a `*_candidate_sampler` function. (if None, we default to `log_uniform_candidate_sampler`) * <b>`remove_accidental_hits`</b>: A `bool`. whether to remove "accidental hits" where a sampled class equals one of the target classes. Default is @@ -919,7 +923,7 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math. ##### Returns: - A batch_size 1-D tensor of per-example sampled softmax losses. + A `batch_size` 1-D tensor of per-example sampled softmax losses. @@ -1180,7 +1184,7 @@ compute them approximately. ### `tf.nn.compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None)` {#compute_accidental_hits} -Compute the ids of positions in sampled_candidates matching true_classes. +Compute the position ids in `sampled_candidates` matching `true_classes`. In Candidate Sampling, this operation facilitates virtually removing sampled classes which happen to match target classes. This is done diff --git a/tensorflow/g3doc/api_docs/python/sparse_ops.md b/tensorflow/g3doc/api_docs/python/sparse_ops.md index 13220c70fcf..4c7db4b10f5 100644 --- a/tensorflow/g3doc/api_docs/python/sparse_ops.md +++ b/tensorflow/g3doc/api_docs/python/sparse_ops.md @@ -91,6 +91,7 @@ The indices of non-zero values in the represented dense tensor. A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the number of non-zero values in the tensor, and `ndims` is the rank. + - - - #### `tf.SparseTensor.values` {#SparseTensor.values} @@ -101,18 +102,21 @@ The non-zero values in the represented dense tensor. A 1-D Tensor of any data type. + - - - #### `tf.SparseTensor.dtype` {#SparseTensor.dtype} The `DType` of elements in this tensor. + - - - #### `tf.SparseTensor.shape` {#SparseTensor.shape} A 1-D Tensor of int64 representing the shape of the dense tensor. + - - - #### `tf.SparseTensor.graph` {#SparseTensor.graph} @@ -120,6 +124,7 @@ A 1-D Tensor of int64 representing the shape of the dense tensor. The `Graph` that contains the index, value, and shape tensors. + - - - ### `class tf.SparseTensorValue` {#SparseTensorValue} @@ -131,12 +136,14 @@ SparseTensorValue(indices, values, shape) Alias for field number 0 + - - - #### `tf.SparseTensorValue.shape` {#SparseTensorValue.shape} Alias for field number 2 + - - - #### `tf.SparseTensorValue.values` {#SparseTensorValue.values} @@ -145,6 +152,7 @@ Alias for field number 1 + ## Sparse to Dense Conversion - - - @@ -363,7 +371,7 @@ is during manual manipulation of the indices and values to add entries. Reordering does not affect the shape of the `SparseTensor`. -For example, if sp_input has shape `[4, 5]` and `indices` / `values`: +For example, if `sp_input` has shape `[4, 5]` and `indices` / `values`: [0, 3]: b [0, 1]: a diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 41a44909632..cb9a090ebda 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -332,12 +332,14 @@ Properties. The name of this variable. + - - - #### `tf.Variable.dtype` {#Variable.dtype} The `DType` of this variable. + - - - #### `tf.Variable.get_shape()` {#Variable.get_shape} @@ -355,18 +357,21 @@ The `TensorShape` of this variable. The device of this variable. + - - - #### `tf.Variable.initializer` {#Variable.initializer} The initializer operation for this variable. + - - - #### `tf.Variable.graph` {#Variable.graph} The `Graph` of this variable. + - - - #### `tf.Variable.op` {#Variable.op} @@ -375,6 +380,7 @@ The `Operation` of this variable. + ## Variable helper functions TensorFlow provides a set of functions to help manage the set of variables @@ -587,43 +593,43 @@ saver = tf.train.Saver([v1, v2]) saver = tf.train.Saver({v.op.name: v for v in [v1, v2]}) ``` -The optional `reshape` argument, if True, allows restoring a variable from +The optional `reshape` argument, if `True`, allows restoring a variable from a save file where the variable had a different shape, but the same number of elements and type. This is useful if you have reshaped a variable and want to reload it from an older checkpoint. -The optional `sharded` argument, if True, instructs the saver to shard +The optional `sharded` argument, if `True`, instructs the saver to shard checkpoints per device. ##### Args: -* <b>`var_list`</b>: A list of Variables or a dictionary mapping names to - Variables. If None, defaults to the list of all variables. -* <b>`reshape`</b>: If True, allows restoring parameters from a checkpoint +* <b>`var_list`</b>: A list of `Variable` objects or a dictionary mapping names to + variables. If `None`, defaults to the list of all variables. +* <b>`reshape`</b>: If `True`, allows restoring parameters from a checkpoint where the variables have a different shape. -* <b>`sharded`</b>: If True, shard the checkpoints, one per device. +* <b>`sharded`</b>: If `True`, shard the checkpoints, one per device. * <b>`max_to_keep`</b>: maximum number of recent checkpoints to keep. Defaults to 10,000 hours. * <b>`keep_checkpoint_every_n_hours`</b>: How often to keep checkpoints. Defaults to 10,000 hours. * <b>`name`</b>: string. Optional name to use as a prefix when adding operations. -* <b>`restore_sequentially`</b>: A Bool, which if true, causes restore of different +* <b>`restore_sequentially`</b>: A `Bool`, which if true, causes restore of different variables to happen sequentially within each device. This can lower memory usage when restoring very large models. -* <b>`saver_def`</b>: Optional SaverDef proto to use instead of running the builder. - This is only useful for specialty code that wants to recreate a Saver - object for a previously built Graph that had a Saver. The saver_def - proto should be the one returned by the as_saver_def() call of the - Saver that was created for that Graph. -* <b>`builder`</b>: Optional SaverBuilder to use if a saver_def was not provided. - Defaults to BaseSaverBuilder(). +* <b>`saver_def`</b>: Optional `SaverDef` proto to use instead of running the + builder. This is only useful for specialty code that wants to recreate + a `Saver` object for a previously built `Graph` that had a `Saver`. + The `saver_def` proto should be the one returned by the + `as_saver_def()` call of the `Saver` that was created for that `Graph`. +* <b>`builder`</b>: Optional `SaverBuilder` to use if a `saver_def` was not provided. + Defaults to `BaseSaverBuilder()`. ##### Raises: * <b>`TypeError`</b>: If `var_list` is invalid. -* <b>`ValueError`</b>: If any of the keys or values in `var_list` is not unique. +* <b>`ValueError`</b>: If any of the keys or values in `var_list` are not unique. - - - @@ -647,7 +653,7 @@ path can be passed directly to a call to `restore()`. `sharded`, this is the prefix of the sharded checkpoint filename. * <b>`global_step`</b>: If provided the global step number is appended to `save_path` to create the checkpoint filename. The optional argument - can be a Tensor, a Tensor name or an integer. + can be a `Tensor`, a `Tensor` name or an integer. * <b>`latest_filename`</b>: Optional name for the protocol buffer file that will contains the list of most recent checkpoint filenames. That file, kept in the same directory as the checkpoint files, is automatically @@ -663,7 +669,7 @@ path can be passed directly to a call to `restore()`. ##### Raises: -* <b>`TypeError`</b>: If `sess` is not a Session. +* <b>`TypeError`</b>: If `sess` is not a `Session`. - - - @@ -683,7 +689,7 @@ The `save_path` argument is typically a value previously returned from a ##### Args: -* <b>`sess`</b>: A Session to use to restore the parameters. +* <b>`sess`</b>: A `Session` to use to restore the parameters. * <b>`save_path`</b>: Path where parameters were previously saved. @@ -702,21 +708,22 @@ You can pass any of the returned values to `restore()`. A list of checkpoint filenames, sorted from oldest to newest. + - - - #### `tf.train.Saver.set_last_checkpoints(last_checkpoints)` {#Saver.set_last_checkpoints} -Sets the list of not-yet-deleted checkpoint filenames. +Sets the list of old checkpoint filenames. ##### Args: -* <b>`last_checkpoints`</b>: a list of checkpoint filenames. +* <b>`last_checkpoints`</b>: A list of checkpoint filenames. ##### Raises: -* <b>`AssertionError`</b>: if the list of checkpoint filenames has already been set. +* <b>`AssertionError`</b>: If the list of checkpoint filenames has already been set. - - - @@ -748,7 +755,7 @@ Finds the filename of latest saved checkpoint file. ##### Returns: - The full path to the latest checkpoint or None if no checkpoint was found. + The full path to the latest checkpoint or `None` if no checkpoint was found. @@ -1319,12 +1326,14 @@ Creates an `IndexedSlices`. A `Tensor` containing the values of the slices. + - - - #### `tf.IndexedSlices.indices` {#IndexedSlices.indices} A 1-D `Tensor` containing the indices of the slices. + - - - #### `tf.IndexedSlices.dense_shape` {#IndexedSlices.dense_shape} @@ -1332,24 +1341,28 @@ A 1-D `Tensor` containing the indices of the slices. A 1-D `Tensor` containing the shape of the corresponding dense tensor. + - - - #### `tf.IndexedSlices.name` {#IndexedSlices.name} The name of this `IndexedSlices`. + - - - #### `tf.IndexedSlices.dtype` {#IndexedSlices.dtype} The `DType` of elements in this tensor. + - - - #### `tf.IndexedSlices.device` {#IndexedSlices.device} The name of the device on which `values` will be produced, or `None`. + - - - #### `tf.IndexedSlices.op` {#IndexedSlices.op} @@ -1357,3 +1370,4 @@ The name of the device on which `values` will be produced, or `None`. The `Operation` that produces `values` as an output. + diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index dcdc25cbbb6..6b36d913565 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -26,7 +26,7 @@ class directly, but instead instantiate one of its subclasses such as ### Usage -``` +```python # Create an optimizer with the desired parameters. opt = GradientDescentOptimizer(learning_rate=0.1) # Add Ops to the graph to minimize a cost by updating a list of variables. @@ -37,7 +37,7 @@ opt_op = opt.minimize(cost, <list of variables>) In the training program you will just have to run the returned Op. -``` +```python # Execute opt_op to do one step of training: opt_op.run() ``` @@ -54,7 +54,7 @@ before applying them you can instead use the optimizer in three steps: Example: -``` +```python # Create an optimizer. opt = GradientDescentOptimizer(learning_rate=0.1) @@ -96,49 +96,49 @@ This must be called by the constructors of subclasses. #### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` {#Optimizer.minimize} -Add operations to minimize 'loss' by updating 'var_list'. +Add operations to minimize `loss` by updating `var_list`. -This method simply combines calls compute_gradients() and -apply_gradients(). If you want to process the gradient before applying them -call compute_gradients() and apply_gradients() explicitly instead of using -this function. +This method simply combines calls `compute_gradients()` and +`apply_gradients()`. If you want to process the gradient before applying +them call `compute_gradients()` and `apply_gradients()` explicitly instead +of using this function. ##### Args: -* <b>`loss`</b>: A Tensor containing the value to minimize. -* <b>`global_step`</b>: Optional Variable to increment by one after the +* <b>`loss`</b>: A `Tensor` containing the value to minimize. +* <b>`global_step`</b>: Optional `Variable` to increment by one after the variables have been updated. -* <b>`var_list`</b>: Optional list of variables.Variable to update to minimize - 'loss'. Defaults to the list of variables collected in the graph - under the key GraphKeys.TRAINABLE_VARIABLES. +* <b>`var_list`</b>: Optional list of `Variable` objects to update to minimize + `loss`. Defaults to the list of variables collected in the graph + under the key `GraphKeys.TRAINABLE_VARIABLES`. * <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be - GATE_NONE, GATE_OP, or GATE_GRAPH. + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. * <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. * <b>`name`</b>: Optional name for the returned operation. ##### Returns: - An Operation that updates the variables in 'var_list'. If 'global_step' - was not None, that operation also increments global_step. + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. ##### Raises: -* <b>`ValueError`</b>: if some of the variables are not variables.Variable objects. +* <b>`ValueError`</b>: if some of the variables are not `Variable` objects. - - - #### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` {#Optimizer.compute_gradients} -Compute gradients of "loss" for the variables in "var_list". +Compute gradients of `loss` for the variables in `var_list`. -This is the first part of minimize(). It returns a list +This is the first part of `minimize()`. It returns a list of (gradient, variable) pairs where "gradient" is the gradient -for "variable". Note that "gradient" can be a Tensor, a -IndexedSlices, or None if there is no gradient for the +for "variable". Note that "gradient" can be a `Tensor`, an +`IndexedSlices`, or `None` if there is no gradient for the given variable. ##### Args: @@ -146,10 +146,10 @@ given variable. * <b>`loss`</b>: A Tensor containing the value to minimize. * <b>`var_list`</b>: Optional list of variables.Variable to update to minimize - "loss". Defaults to the list of variables collected in the graph - under the key GraphKey.TRAINABLE_VARIABLES. + `loss`. Defaults to the list of variables collected in the graph + under the key `GraphKey.TRAINABLE_VARIABLES`. * <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be - GATE_NONE, GATE_OP, or GATE_GRAPH. + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. * <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. @@ -160,7 +160,7 @@ given variable. ##### Raises: -* <b>`TypeError`</b>: If var_list contains anything else than variables.Variable. +* <b>`TypeError`</b>: If `var_list` contains anything else than `Variable` objects. * <b>`ValueError`</b>: If some arguments are invalid. @@ -170,28 +170,28 @@ given variable. Apply gradients to variables. -This is the second part of minimize(). It returns an Operation that +This is the second part of `minimize()`. It returns an `Operation` that applies gradients. ##### Args: * <b>`grads_and_vars`</b>: List of (gradient, variable) pairs as returned by - compute_gradients(). -* <b>`global_step`</b>: Optional Variable to increment by one after the + `compute_gradients()`. +* <b>`global_step`</b>: Optional `Variable` to increment by one after the variables have been updated. * <b>`name`</b>: Optional name for the returned operation. Default to the - name passed to the Optimizer constructor. + name passed to the `Optimizer` constructor. ##### Returns: - An Operation that applies the specified gradients. If 'global_step' - was not None, that operation also increments global_step. + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. ##### Raises: -* <b>`TypeError`</b>: if grads_and_vars is malformed. +* <b>`TypeError`</b>: if `grads_and_vars` is malformed. @@ -203,18 +203,18 @@ gradients. The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`. -<b>GATE_NONE</b>: Compute and apply gradients in parallel. This provides the -maximum parallelism in execution, at the cost of some non-reproducibility in -the results. For example the two gradients of MatMul depend on the input +<b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides +the maximum parallelism in execution, at the cost of some non-reproducibility +in the results. For example the two gradients of `matmul` depend on the input values: With `GATE_NONE` one of the gradients could be applied to one of the inputs _before_ the other gradient is computed resulting in non-reproducible results. -<b>GATE_OP</b>: For each Op, make sure all gradients are computed before they -are used. This prevents race conditions for Ops that generate gradients for -multiple inputs where the gradients depend on the inputs. +<b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before +they are used. This prevents race conditions for Ops that generate gradients +for multiple inputs where the gradients depend on the inputs. -<b>GATE_GRAPH</b>: Make sure all gradients for all variables are computed +<b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed before any one of them is used. This provides the least parallelism but can be useful if you want to process all gradients before applying any of them. @@ -233,9 +233,9 @@ about the slots, etc. #### `tf.train.Optimizer.get_slot_names()` {#Optimizer.get_slot_names} -Return a list of the names of slots created by the Optimizer. +Return a list of the names of slots created by the `Optimizer`. -See get_slot(). +See `get_slot()`. ##### Returns: @@ -246,23 +246,24 @@ See get_slot(). #### `tf.train.Optimizer.get_slot(var, name)` {#Optimizer.get_slot} -Return a slot named "name" created for "var" by the Optimizer. +Return a slot named `name` created for `var` by the Optimizer. -Some Optimizer subclasses use additional variables. For example -Momentum and Adagrad use variables to accumulate updates. This method -gives access to these Variables if for some reason you need them. +Some `Optimizer` subclasses use additional variables. For example +`Momentum` and `Adagrad` use variables to accumulate updates. This method +gives access to these `Variable` objects if for some reason you need them. -Use get_slot_names() to get the list of slot names created by the Optimizer. +Use `get_slot_names()` to get the list of slot names created by the +`Optimizer`. ##### Args: -* <b>`var`</b>: A variable passed to minimize() or apply_gradients(). +* <b>`var`</b>: A variable passed to `minimize()` or `apply_gradients()`. * <b>`name`</b>: A string. ##### Returns: - The Variable for the slot if it was created, None otherwise. + The `Variable` for the slot if it was created, `None` otherwise. @@ -315,7 +316,7 @@ Construct a new Adagrad optimizer. ##### Raises: -* <b>`ValueError`</b>: If the initial_accumulator_value is invalid. +* <b>`ValueError`</b>: If the `initial_accumulator_value` is invalid. @@ -505,7 +506,7 @@ for y in `ys`. `grad_ys` is a list of tensors of the same length as `ys` that holds the initial gradients for each y in `ys`. When `grad_ys` is None, we fill in a tensor of '1's of the shape of y for each y in `ys`. A -user can provide their own initial 'grad_ys` to compute the +user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each y (e.g., if one wanted to weight the gradient differently for each value in each y). @@ -631,7 +632,7 @@ greater than `clip_value_max` are set to `clip_value_max`. Clips tensor values to a maximum L2-norm. Given a tensor `t`, and a maximum clip value `clip_norm`, this operation -normalizes `t` so that its L2-norm is less than or equal to `clip_norm'. +normalizes `t` so that its L2-norm is less than or equal to `clip_norm`. Specifically, if the L2-norm is already less than or equal to `clip_norm`, then `t` is not modified. If the L2-norm is greater than `clip_norm`, then this operation returns a tensor of the same type and shape as `t` with its @@ -664,7 +665,7 @@ Clips tensor values to a maximum average L2-norm. Given a tensor `t`, and a maximum clip value `clip_norm`, this operation normalizes `t` so that its average L2-norm is less than or equal to -`clip_norm'. Specifically, if the average L2-norm is already less than or +`clip_norm`. Specifically, if the average L2-norm is already less than or equal to `clip_norm`, then `t` is not modified. If the average L2-norm is greater than `clip_norm`, then this operation returns a tensor of the same type and shape as `t` with its values set to: @@ -700,18 +701,18 @@ and the global norm (`global_norm`) of all tensors in `t_list`. Optionally, if you've already computed the global norm for `t_list`, you can specify the global norm with `use_norm`. -To perform the clipping, the values t_list[i] are set to: +To perform the clipping, the values `t_list[i]` are set to: -`t_list[i] * clip_norm / max(global_norm, clip_norm)` + t_list[i] * clip_norm / max(global_norm, clip_norm) where: -`global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))` + global_norm = sqrt(sum([l2norm(t)**2 for t in t_list])) If `clip_norm > global_norm` then the entries in `t_list` remain as they are, otherwise they're all shrunk by the global ratio. -Any of the entries of `t_list` that are of type None are ignored. +Any of the entries of `t_list` that are of type `None` are ignored. This is the correct way to perform gradient clipping (for example, see R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training @@ -1136,28 +1137,28 @@ Create a new Coordinator. Wait for threads to terminate. -Blocks until all 'threads' have terminated or request_stop() is called. +Blocks until all `threads` have terminated or `request_stop()` is called. -After the threads stop, if an 'exc_info' was passed to request_stop, that +After the threads stop, if an `exc_info` was passed to `request_stop`, that exception is re-reaised. -Grace period handling: When request_stop() is called, threads are given +Grace period handling: When `request_stop()` is called, threads are given 'stop_grace_period_secs' seconds to terminate. If any of them is still -alive after that period expires, a RuntimeError is raised. Note that if -an 'exc_info' was passed to request_stop() then it is raised instead of -that RuntimeError. +alive after that period expires, a `RuntimeError` is raised. Note that if +an `exc_info` was passed to `request_stop()` then it is raised instead of +that `RuntimeError`. ##### Args: -* <b>`threads`</b>: List threading.Threads. The started threads to join. +* <b>`threads`</b>: List of `threading.Threads`. The started threads to join. * <b>`stop_grace_period_secs`</b>: Number of seconds given to threads to stop after - request_stop() has been called. + `request_stop()` has been called. ##### Raises: -* <b>`RuntimeError`</b>: If any thread is still alive after request_stop() +* <b>`RuntimeError`</b>: If any thread is still alive after `request_stop()` is called and the grace period expires. @@ -1167,14 +1168,14 @@ that RuntimeError. Request that the threads stop. -After this is called, calls to should_stop() will return True. +After this is called, calls to `should_stop()` will return `True`. ##### Args: -* <b>`ex`</b>: Optional Exception, or Python 'exc_info' tuple as returned by - sys.exc_info(). If this is the first call to request_stop() the - corresponding exception is recorded and re-raised from join(). +* <b>`ex`</b>: Optional `Exception`, or Python `exc_info` tuple as returned by + `sys.exc_info()`. If this is the first call to `request_stop()` the + corresponding exception is recorded and re-raised from `join()`. - - - @@ -1306,6 +1307,7 @@ depending on whether or not a `Coordinator` was passed to was captured. (No exceptions are captured when using a Coordinator.) + - - - ### `tf.train.add_queue_runner(qr, collection='queue_runners')` {#add_queue_runner} @@ -1768,7 +1770,7 @@ global_step: 10 Writes a graph proto on disk. -The graph is written as a binary proto unless as_text is `True`. +The graph is written as a binary proto unless `as_text` is `True`. ```python v = tf.Variable(0, name='my_variable') diff --git a/tensorflow/g3doc/extras/candidate_sampling.pdf b/tensorflow/g3doc/extras/candidate_sampling.pdf deleted file mode 100644 index d4a1c64c932..00000000000 Binary files a/tensorflow/g3doc/extras/candidate_sampling.pdf and /dev/null differ diff --git a/tensorflow/g3doc/extras/tensorflow-whitepaper2015.pdf b/tensorflow/g3doc/extras/tensorflow-whitepaper2015.pdf deleted file mode 100644 index 8f977a7f0d1..00000000000 Binary files a/tensorflow/g3doc/extras/tensorflow-whitepaper2015.pdf and /dev/null differ diff --git a/tensorflow/g3doc/get_started/blue_pill.png b/tensorflow/g3doc/get_started/blue_pill.png deleted file mode 100644 index bd3c90f742f..00000000000 Binary files a/tensorflow/g3doc/get_started/blue_pill.png and /dev/null differ diff --git a/tensorflow/g3doc/get_started/index.md b/tensorflow/g3doc/get_started/index.md index f8095df40d3..e9e0bcd0cbc 100644 --- a/tensorflow/g3doc/get_started/index.md +++ b/tensorflow/g3doc/get_started/index.md @@ -2,45 +2,47 @@ Let's get you up and running with TensorFlow! -But before we even get started, let's give you a sneak peek at what TensorFlow -code looks like in the Python API, just so you have a sense of where we're +But before we even get started, let's peek at what TensorFlow +code looks like in the Python API, so you have a sense of where we're headed. -Here's a little Python program that makes up some data in three dimensions, and -then fits a plane to it. +Here's a little Python program that makes up some data in two dimensions, and +then fits a line to it. ```python import tensorflow as tf import numpy as np -# Make 100 phony data points in NumPy. -x_data = np.float32(np.random.rand(2, 100)) # Random input -y_data = np.dot([0.100, 0.200], x_data) + 0.300 +# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3 +x_data = np.random.rand(100).astype("float32") +y_data = x_data * 0.1 + 0.3 -# Construct a linear model. +# Try to find values for W and b that compute y_data = W * x_data + b +# (We know that W should be 0.1 and b 0.3, but Tensorflow will +# figure that out for us.) +W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) -W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0)) -y = tf.matmul(W, x_data) + b +y = W * x_data + b -# Minimize the squared errors. +# Minimize the mean squared errors. loss = tf.reduce_mean(tf.square(y - y_data)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) -# For initializing the variables. +# Before starting, initialize the variables. We will 'run' this first. init = tf.initialize_all_variables() -# Launch the graph +# Launch the graph. sess = tf.Session() sess.run(init) -# Fit the plane. -for step in xrange(0, 201): +# Fit the line. +for step in xrange(201): sess.run(train) if step % 20 == 0: print step, sess.run(W), sess.run(b) -# Learns best fit is W: [[0.100 0.200]], b: [0.300] +# Learns best fit is W: [0.1], b: [0.3] ``` The first part of this code builds the data flow graph. TensorFlow does not diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index cf970cacc4d..fde583e47f7 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -1,133 +1,189 @@ # Download and Setup -You can install TensorFlow using our provided binary packages or from source. +You can install TensorFlow either from our provided binary packages or from the +github source. -## Binary Installation +## Requirements -The TensorFlow Python API currently requires Python 2.7: we are -[working](https://github.com/tensorflow/tensorflow/issues/1) on adding support -for Python 3. +The TensorFlow Python API currently requires Python 2.7. We are +[adding support for Python 3](https://github.com/tensorflow/tensorflow/issues/1). -The simplest way to install TensorFlow is using -[pip](https://pypi.python.org/pypi/pip) for both Linux and Mac. +The GPU version (Linux only) currently requires the Cuda Toolkit 7.0 and CUDNN +6.5 V2. Please see [Cuda installation](#install_cuda). + +## Overview + +We support different ways to install TensorFlow: + +* [Pip install](#pip_install): Install TensorFlow on your machine, possibly + upgrading previously installed Python packages. May impact existing + Python programs on your machine. +* [Virtualenv install](#virtualenv_install): Install TensorFlow in its own + directory, not impacting any existing Python programs on your machine. +* [Docker install](#docker_install): Run TensorFlow in a Docker container + isolated from all other programs on your machine. + +If you are familiar with Pip, Virtualenv, or Docker, please feel free to adapt +the instructions to your particular needs. The names of the pip and Docker +images are listed in the corresponding installation sections. If you encounter installation errors, see -[common problems](#common_install_problems) for some solutions. To simplify -installation, please consider using our virtualenv-based instructions -[here](#virtualenv_install). +[common problems](#common_install_problems) for some solutions. -### Ubuntu/Linux 64-bit +## Pip Installation {#pip_install} + +[Pip](https://en.wikipedia.org/wiki/Pip_(package_manager)) is a package +management system used to install and manage software packages written in +Python. + +The packages that will be installed or upgraded during the pip install are listed in the +[REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py) + +Install pip if not already installed: ```bash -# For CPU-only version -$ pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl +# Ubuntu/Linux 64-bit +$ sudo apt-get install python-pip python-dev -# For GPU-enabled version (only install this version if you have the CUDA sdk installed) -$ pip install https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl +# Mac OS X +$ sudo easy_install pip ``` -### Mac OS X - -On OS X, we recommend installing [homebrew](http://brew.sh) and `brew install -python` before proceeding, or installing TensorFlow within [virtualenv](#virtualenv_install). +Install TensorFlow: ```bash -# Only CPU-version is available at the moment. -$ pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl +# Ubuntu/Linux 64-bit, CPU only: +$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl + +# Ubuntu/Linux 64-bit, GPU enabled: +$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl + +# Mac OS X, CPU only: +$ sudo easy_install --upgrade six +$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl ``` -## Docker-based installation +You can now [test your installation](#test_install). -We also support running TensorFlow via [Docker](http://docker.com/), which lets -you avoid worrying about setting up dependencies. +## Virtualenv installation {#virtualenv_install} -First, [install Docker](http://docs.docker.com/engine/installation/). Once -Docker is up and running, you can start a container with one command: +[Virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) is a tool +to keep the dependencies required by different Python projects in separate +places. The Virtualenv installation of TensorFlow will not override +pre-existing version of the Python packages needed by TensorFlow. + +With [Virtualenv](https://pypi.python.org/pypi/virtualenv) the installation is +as follows: + +* Install pip and Virtualenv. +* Create a Virtualenv environment. +* Activate the Virtualenv environment and install TensorFlow in it. +* After the install you will activate the Virtualenv environment each time you + want to use TensorFlow. + +Install pip and Virtualenv: + +```bash +# Ubuntu/Linux 64-bit +$ sudo apt-get install python-pip python-dev python-virtualenv + +# Mac OS X +$ sudo easy_install pip +$ sudo pip install --upgrade virtualenv +``` + +Create a Virtualenv environment in the directory `~/tensorflow`: + +```bash +$ virtualenv --system-site-packages ~/tensorflow +``` + +Activate the environment and use pip to install TensorFlow inside it: + +```bash +$ source ~/tensorflow/bin/activate # If using bash +$ source ~/tensorflow/bin/activate.csh # If using csh +(tensorflow)$ # Your prompt should change + +# Ubuntu/Linux 64-bit, CPU only: +(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl + +# Ubuntu/Linux 64-bit, GPU enabled: +(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl + +# Mac OS X, CPU only: +(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl +``` + +With the Virtualenv environment activated, you can now +[test your installation](#test_install). + +```bash +# When you are done using TensorFlow, deactivate the environment. +(tensorflow)$ deactivate + +$ # Your prompt should change back +``` + +To use TensorFlow later you will have to activate the Virtualenv environment again: + +```bash +$ source ~/tensorflow/bin/activate # If using bash. +$ source ~/tensorflow/bin/activate.csh # If using csh. +(tensorflow)$ # Your prompt should change. +# Run Python programs that use TensorFlow. +... +# When you are done using TensorFlow, deactivate the environment. +(tensorflow)$ deactivate +``` + +## Docker installation {#docker_install} + +[Docker](http://docker.com/) is a system to build self contained versions of a +Linux operating system running on your machine. When you install and run +TensorFlow via Docker it completely isolates the installation from pre-existing +packages on your machine. + +We provide 2 Docker images: + +* `b.gcr.io/tensorflow/tensorflow`: TensorFlow CPU binary image. +* `b.gcr.io/tensorflow/tensorflow-full`: CPU Binary image plus source code. + +With Docker the installation is as follows: + +* Install Docker on your machine. +* Launch a Docker container with the TensorFlow image. The image + gets downloaded automatically on first launch. + +See [installing Docker](http://docs.docker.com/engine/installation/) for instructions +on installing Docker on your machine. + +Also create a [Docker +group](http://docs.docker.com/engine/installation/ubuntulinux/#create-a-docker-group) +to allow launching containers without `sudo`. + +After Docker is installed, launch a Docker container with the TensorFlow binary +image as follows. ```bash $ docker run -it b.gcr.io/tensorflow/tensorflow ``` -This will start a container with TensorFlow and all its dependencies already -installed. +Within the Docker container, you can now [test your installation](#test_install). -### Additional images - -The default Docker image above contains just a minimal set of libraries for -getting up and running with TensorFlow. We also have the following container, -which you can use in the `docker run` command above: - -* `b.gcr.io/tensorflow/tensorflow-full`: Contains a complete TensorFlow source - installation, including all utilities needed to build and run TensorFlow. This - makes it easy to experiment directly with the source, without needing to - install any of the dependencies described above. - -## VirtualEnv-based installation {#virtualenv_install} - -We recommend using [virtualenv](https://pypi.python.org/pypi/virtualenv) to -create an isolated container and install TensorFlow in that container -- it is -optional but makes verifying installation issues easier. - -First, install all required tools: +You can alternatively launch the TensorFlow source image, for example if you want +to experiment directly with the source. ```bash -# On Linux: -$ sudo apt-get install python-pip python-dev python-virtualenv - -# On Mac: -$ sudo easy_install pip # If pip is not already installed -$ sudo pip install --upgrade virtualenv +$ docker run -it b.gcr.io/tensorflow/tensorflow-full ``` -Next, set up a new virtualenv environment. To set it up in the -directory `~/tensorflow`, run: +## Test the TensorFlow installation {#test_install} -```bash -$ virtualenv --system-site-packages ~/tensorflow -$ cd ~/tensorflow -``` +### (Optional, Linux) Enable GPU Support -Then activate the virtualenv: - -```bash -$ source bin/activate # If using bash -$ source bin/activate.csh # If using csh -(tensorflow)$ # Your prompt should change -``` - -Inside the virtualenv, install TensorFlow: - -```bash -# For CPU-only linux x86_64 version -(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl - -# For GPU-enabled linux x86_64 version -(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl - -# For Mac CPU-only version -(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl -``` - -Make sure you have downloaded the source code for TensorFlow, and then you can -then run an example TensorFlow program like: - -```bash -(tensorflow)$ cd tensorflow/models/image/mnist -(tensorflow)$ python convolutional.py - -# When you are done using TensorFlow: -(tensorflow)$ deactivate # Deactivate the virtualenv - -$ # Your prompt should change back -``` - -## Try your first TensorFlow program - -### (Optional) Enable GPU Support - -If you installed the GPU-enabled TensorFlow pip binary, you must have the -correct versions of the CUDA SDK and CUDNN installed on your -system. Please see [the CUDA installation instructions](#install_cuda). +If you installed the GPU version of TensorFlow, you must also install the Cuda +Toolkit 7.0 and CUDNN 6.5 V2. Please see [Cuda installation](#install_cuda). You also need to set the `LD_LIBRARY_PATH` and `CUDA_HOME` environment variables. Consider adding the commands below to your `~/.bash_profile`. These @@ -138,13 +194,15 @@ export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64" export CUDA_HOME=/usr/local/cuda ``` -### Run TensorFlow +### Run TensorFlow from the Command Line -Open a python terminal: +See [common problems](#common_install_problems) if some error happens. + +Open a terminal and type the following: ```bash $ python - +... >>> import tensorflow as tf >>> hello = tf.constant('Hello, TensorFlow!') >>> sess = tf.Session() @@ -152,10 +210,43 @@ $ python Hello, TensorFlow! >>> a = tf.constant(10) >>> b = tf.constant(32) ->>> print sess.run(a+b) +>>> print sess.run(a + b) 42 >>> +``` +### Run a TensorFlow demo model + +All TensorFlow packages, including the demo models, are installed in the Python library. +The exact location of the Python library depends on your system, but is usually one of: + +```bash +/usr/local/lib/python2.7/dist-packages/tensorflow +/usr/local/lib/python2.7/site-packages/tensorflow +``` + +You can find out the directory with the following command: + +```bash +$ python -c 'import site; print "\n".join(site.getsitepackages())' +``` + +The simple demo model for classifying handwritten digits from the MNIST dataset +is in the sub-directory `models/image/mnist/convolutional.py`. You can run it from the command +line as follows: + +```bash +# Using 'python -m' to find the program in the python search path: +$ python -m tensorflow.models.image.mnist.convolutional +Extracting data/train-images-idx3-ubyte.gz +Extracting data/train-labels-idx1-ubyte.gz +Extracting data/t10k-images-idx3-ubyte.gz +Extracting data/t10k-labels-idx1-ubyte.gz +...etc... + +# You can alternatively pass the path to the model program file to the python interpreter. +$ python /usr/local/lib/python2.7/dist-packages/tensorflow/models/image/mnist/convolutional.py +... ``` ## Installing from sources {#source} @@ -427,39 +518,39 @@ SyntaxError: invalid syntax Solution: make sure you are using Python 2.7. -### On MacOSX +### Mac OS X: ImportError: No module named copyreg - -If you encounter: +On Mac OS X, you may encounter the following when importing tensorflow. ```python -import six.moves.copyreg as copyreg - +>>> import tensorflow as tf +... ImportError: No module named copyreg ``` -Solution: TensorFlow depends on protobuf, which requires `six-1.10.0`. Apple's -default python environment has `six-1.4.1` and may be difficult to upgrade. -There are several ways to fix this: +Solution: TensorFlow depends on protobuf, which requires the Python package +`six-1.10.0`. Apple's default Python installation only provides `six-1.4.1`. -1. Upgrade the system-wide copy of `six`: +You can resolve the issue in one of the following ways: - ```bash - sudo easy_install -U six - ``` +* Upgrade the Python installation with the current version `six`: -2. Install a separate copy of python via homebrew: +```bash +$ sudo easy_install -U six +``` - ```bash - brew install python - ``` +* Install TensorFlow with a separate Python library: -3. Build or use TensorFlow - [within `virtualenv`](#virtualenv_install). + * Using [Virtualenv](#virtualenv_install). + * Using [Docker](#docker_install). +* Install a separate copy of Python via [Homebrew](http://brew.sh/) or +[MacPorts](https://www.macports.org/) and re-install TensorFlow in that +copy of Python. +# Mac OS X: TypeError: `__init__()` got an unexpected keyword argument 'syntax' -If you encounter: +On Mac OS X, you may encounter the following when importing tensorflow. ``` >>> import tensorflow as tf @@ -480,5 +571,5 @@ The best current solution is to make sure older versions of protobuf are not installed, such as: ```bash -brew reinstall --devel protobuf +$ pip install --upgrade protobuf ``` diff --git a/tensorflow/g3doc/get_started/red_pill.png b/tensorflow/g3doc/get_started/red_pill.png deleted file mode 100644 index a01ca86247f..00000000000 Binary files a/tensorflow/g3doc/get_started/red_pill.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/adding_an_op/index.md b/tensorflow/g3doc/how_tos/adding_an_op/index.md index d492a9ac635..150ad8d6e68 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/index.md +++ b/tensorflow/g3doc/how_tos/adding_an_op/index.md @@ -117,7 +117,7 @@ Python op wrappers are created automatically in `bazel-genfiles/tensorflow/python/ops/gen_user_ops.py` for all ops placed in the [`tensorflow/core/user_ops`][user_ops] directory when you build Tensorflow. -> Note: The generated function will be given a snake_case name (to comply with +> Note: The generated function will be given a snake\_case name (to comply with > [PEP8](https://www.python.org/dev/peps/pep-0008/)). So if your op is named > `ZeroOut` in the C++ files, the python function will be called `zero_out`. @@ -294,7 +294,7 @@ which can then be used in the `Compute` method: <code class="lang-c++"><pre> void Compute(OpKernelContext\* context) override { // ... -<br/> <b>// Check that preserve_index is in range +<br/> <b>// Check that preserve\_index is in range OP\_REQUIRES(context, preserve\_index_ < input.dimension(0), errors::InvalidArgument("preserve\_index out of range"));<br/> </b>// Set all the elements of the output tensor to 0 @@ -314,7 +314,7 @@ which can then be used in the `Compute` method: > <code class="lang-c++"><pre> > REGISTER\_OP("ZeroOut") > <b>.Attr("preserve\_index: int = 0")</b> -> .Input("to_zero: int32") +> .Input("to\_zero: int32") > .Output("zeroed: int32"); > </pre></code> @@ -449,7 +449,7 @@ in addition to `int32`s, your Op registration might look like: <code class="lang-c++"><pre> REGISTER\_OP("ZeroOut") <b>.Attr("T: {float, int32}")</b> - .Input("to_zero: <b>T</b>") + .Input("to\_zero: <b>T</b>") .Output("zeroed: <b>T</b>"); </pre></code> @@ -457,7 +457,7 @@ Your Op registration now specifies that the input's type must be `float`, or `int32`, and that its output will be the same type, since both have type `T`. > A note on naming:{#naming} Inputs, outputs, and attrs generally should be -> given snake_case names. The one exception is attrs that are used as the type +> given snake\_case names. The one exception is attrs that are used as the type > of an input or in the type of an input. Those attrs can be inferred when the > op is added to the graph and so don't appear in the op's function. For > example, this last definition of ZeroOut will generate a Python function that @@ -560,7 +560,7 @@ REGISTER\_KERNEL\_BUILDER( > <code class="lang-c++"><pre> > REGISTER\_OP("ZeroOut") > <b>.Attr("T: {float, int32} = DT_INT32")</b> -> .Input("to_zero: T") +> .Input("to\_zero: T") > .Output("zeroed: T") > </pre></code> @@ -569,7 +569,7 @@ Lets say you wanted to add more types, say `double`: <code class="lang-c++"><pre> REGISTER\_OP("ZeroOut") <b>.Attr("T: {float, <b>double,</b> int32}")</b> - .Input("to_zero: <b>T</b>") + .Input("to\_zero: <b>T</b>") .Output("zeroed: <b>T</b>"); </pre></code> @@ -888,7 +888,7 @@ There are several ways to preserve backwards-compatibility. type into a list of varying types). The full list of safe and unsafe changes can be found in -[tensorflow/core/framework/op_compatibility_test.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_compatibility_test.cc). +[`tensorflow/core/framework/op_compatibility_test.cc`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_compatibility_test.cc). If you cannot make your change to an operation backwards compatible, then create a new operation with a new name with the new semantics. diff --git a/tensorflow/g3doc/how_tos/graph_viz/colorby_device.png b/tensorflow/g3doc/how_tos/graph_viz/colorby_device.png deleted file mode 100644 index 4a32ea6d614..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/colorby_device.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/colorby_structure.png b/tensorflow/g3doc/how_tos/graph_viz/colorby_structure.png deleted file mode 100644 index 71db521266f..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/colorby_structure.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/constant.png b/tensorflow/g3doc/how_tos/graph_viz/constant.png deleted file mode 100644 index fb1cab6da8d..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/constant.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/control_edge.png b/tensorflow/g3doc/how_tos/graph_viz/control_edge.png deleted file mode 100644 index ee63e15c8fa..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/control_edge.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/conv_1.png b/tensorflow/g3doc/how_tos/graph_viz/conv_1.png deleted file mode 100644 index fc839267603..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/conv_1.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/dataflow_edge.png b/tensorflow/g3doc/how_tos/graph_viz/dataflow_edge.png deleted file mode 100644 index 4204a72ff94..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/dataflow_edge.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/graph_vis_animation.gif b/tensorflow/g3doc/how_tos/graph_viz/graph_vis_animation.gif deleted file mode 100644 index 556383270c4..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/graph_vis_animation.gif and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/horizontal_stack.png b/tensorflow/g3doc/how_tos/graph_viz/horizontal_stack.png deleted file mode 100644 index ec742fa6aa5..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/horizontal_stack.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/infocard.png b/tensorflow/g3doc/how_tos/graph_viz/infocard.png deleted file mode 100644 index ab48e2ee3dc..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/infocard.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/infocard_op.png b/tensorflow/g3doc/how_tos/graph_viz/infocard_op.png deleted file mode 100644 index 2f8e695452a..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/infocard_op.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/namespace_node.png b/tensorflow/g3doc/how_tos/graph_viz/namespace_node.png deleted file mode 100644 index aa92fa1efc3..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/namespace_node.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/op_node.png b/tensorflow/g3doc/how_tos/graph_viz/op_node.png deleted file mode 100644 index 4547e6ebc3b..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/op_node.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/pool1_collapsed.png b/tensorflow/g3doc/how_tos/graph_viz/pool1_collapsed.png deleted file mode 100644 index afc31df504c..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/pool1_collapsed.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/pool1_expanded.png b/tensorflow/g3doc/how_tos/graph_viz/pool1_expanded.png deleted file mode 100644 index 9538003af26..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/pool1_expanded.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/reference_edge.png b/tensorflow/g3doc/how_tos/graph_viz/reference_edge.png deleted file mode 100644 index c456363ed2a..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/reference_edge.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/save.png b/tensorflow/g3doc/how_tos/graph_viz/save.png deleted file mode 100644 index 6a0278b3f5b..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/save.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/series.png b/tensorflow/g3doc/how_tos/graph_viz/series.png deleted file mode 100644 index 1fa8af8b9af..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/series.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/series_expanded.png b/tensorflow/g3doc/how_tos/graph_viz/series_expanded.png deleted file mode 100644 index 3bb2d12409b..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/series_expanded.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/summary.png b/tensorflow/g3doc/how_tos/graph_viz/summary.png deleted file mode 100644 index 239a7806f26..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/summary.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/graph_viz/vertical_stack.png b/tensorflow/g3doc/how_tos/graph_viz/vertical_stack.png deleted file mode 100644 index d22ecfe88ec..00000000000 Binary files a/tensorflow/g3doc/how_tos/graph_viz/vertical_stack.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/reading_data/AnimatedFileQueues.gif b/tensorflow/g3doc/how_tos/reading_data/AnimatedFileQueues.gif deleted file mode 100644 index 623651b0080..00000000000 Binary files a/tensorflow/g3doc/how_tos/reading_data/AnimatedFileQueues.gif and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/reading_data/index.md b/tensorflow/g3doc/how_tos/reading_data/index.md index 089ee4e34d1..907652952a9 100644 --- a/tensorflow/g3doc/how_tos/reading_data/index.md +++ b/tensorflow/g3doc/how_tos/reading_data/index.md @@ -143,7 +143,8 @@ and described in Another approach is to convert whatever data you have into a supported format. This approach makes it easier to mix and match data sets and network -architectures. The recommended format for TensorFlow is a TFRecords file +architectures. The recommended format for TensorFlow is a +[TFRecords file](../../api_docs/python/python_io.md#tfrecords-format-details) containing [`tf.train.Example` protocol buffers](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/example/example.proto) (which contain diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md index 3a8e4787102..fdec071aeec 100644 --- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md +++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md @@ -1,6 +1,6 @@ # TensorBoard: Visualizing Learning -The computations you'll use TensorBoard for - like training a massive +The computations you'll use TensorFlow for - like training a massive deep neural network - can be complex and confusing. To make it easier to understand, debug, and optimize TensorFlow programs, we've included a suite of visualization tools called TensorBoard. You can use TensorBoard to visualize @@ -59,7 +59,7 @@ Also, the `SummaryWriter` can optionally take a `GraphDef` in its constructor. If it receives one, then TensorBoard will visualize your graph as well. Now that you've modified your graph and have a `SummaryWriter`, you're ready to -start runing your network! If you want, you could run the merged summary op +start running your network! If you want, you could run the merged summary op every single step, and record a ton of training data. That's likely to be more data than you need, though. Instead, consider running the merged summary op every hundred steps or so, as in the following code example. diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_tensorboard.png b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_tensorboard.png deleted file mode 100644 index 35ff1057ad1..00000000000 Binary files a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/mnist_tensorboard.png and /dev/null differ diff --git a/tensorflow/g3doc/how_tos/threading_and_queues/IncremeterFifoQueue.gif b/tensorflow/g3doc/how_tos/threading_and_queues/IncremeterFifoQueue.gif deleted file mode 100644 index c7b5fa03ac8..00000000000 Binary files a/tensorflow/g3doc/how_tos/threading_and_queues/IncremeterFifoQueue.gif and /dev/null differ diff --git a/tensorflow/g3doc/images/baseball_network.png b/tensorflow/g3doc/images/baseball_network.png deleted file mode 100644 index 06362dbc739..00000000000 Binary files a/tensorflow/g3doc/images/baseball_network.png and /dev/null differ diff --git a/tensorflow/g3doc/images/getting_started.png b/tensorflow/g3doc/images/getting_started.png deleted file mode 100644 index 09c26fa646f..00000000000 Binary files a/tensorflow/g3doc/images/getting_started.png and /dev/null differ diff --git a/tensorflow/g3doc/images/results.png b/tensorflow/g3doc/images/results.png deleted file mode 100644 index f9c59260e1c..00000000000 Binary files a/tensorflow/g3doc/images/results.png and /dev/null differ diff --git a/tensorflow/g3doc/images/scatterplot.png b/tensorflow/g3doc/images/scatterplot.png deleted file mode 100644 index b19f872ecfd..00000000000 Binary files a/tensorflow/g3doc/images/scatterplot.png and /dev/null differ diff --git a/tensorflow/g3doc/images/tensors_flowing.gif b/tensorflow/g3doc/images/tensors_flowing.gif deleted file mode 100644 index a12ac8da2a2..00000000000 Binary files a/tensorflow/g3doc/images/tensors_flowing.gif and /dev/null differ diff --git a/tensorflow/g3doc/images/tf_logo.png b/tensorflow/g3doc/images/tf_logo.png deleted file mode 100644 index 0c9c7e8f96e..00000000000 Binary files a/tensorflow/g3doc/images/tf_logo.png and /dev/null differ diff --git a/tensorflow/g3doc/images/tf_logo_transp.png b/tensorflow/g3doc/images/tf_logo_transp.png deleted file mode 100644 index a8190c9af6a..00000000000 Binary files a/tensorflow/g3doc/images/tf_logo_transp.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/Parallelism.png b/tensorflow/g3doc/tutorials/deep_cnn/Parallelism.png deleted file mode 100644 index e61594fcebe..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/Parallelism.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_activations.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_activations.png deleted file mode 100644 index 1905a924097..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_activations.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_graph.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_graph.png deleted file mode 100644 index 12bcdb05d20..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_graph.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_image_summary.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_image_summary.png deleted file mode 100644 index ba89981b193..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_image_summary.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_loss.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_loss.png deleted file mode 100644 index fbfc45309bc..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_loss.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_lr_decay.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_lr_decay.png deleted file mode 100644 index 2d72ceaef64..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_lr_decay.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_samples.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_samples.png deleted file mode 100644 index d68b01d8679..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_samples.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_sparsity.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_sparsity.png deleted file mode 100644 index 311c1c37b3c..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_sparsity.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/cifar_var_histograms.png b/tensorflow/g3doc/tutorials/deep_cnn/cifar_var_histograms.png deleted file mode 100644 index 647c863f2f6..00000000000 Binary files a/tensorflow/g3doc/tutorials/deep_cnn/cifar_var_histograms.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/deep_cnn/index.md b/tensorflow/g3doc/tutorials/deep_cnn/index.md index 51989305490..59d106680eb 100644 --- a/tensorflow/g3doc/tutorials/deep_cnn/index.md +++ b/tensorflow/g3doc/tutorials/deep_cnn/index.md @@ -197,7 +197,7 @@ For regularization, we also apply the usual variables. The objective function for the model is the sum of the cross entropy loss and all these weight decay terms, as returned by the `loss()` function. -We visualize it in TensorBoard with a [scalar_summary](../../api_docs/python/train.md#scalar_summary): +We visualize it in TensorBoard with a [`scalar_summary`](../../api_docs/python/train.md#scalar_summary):  diff --git a/tensorflow/g3doc/tutorials/mandelbrot/mandelbrot_output.jpg b/tensorflow/g3doc/tutorials/mandelbrot/mandelbrot_output.jpg deleted file mode 100644 index 8e261d44a84..00000000000 Binary files a/tensorflow/g3doc/tutorials/mandelbrot/mandelbrot_output.jpg and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/MNIST-Matrix.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/MNIST-Matrix.png deleted file mode 100644 index cfde10da30d..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/MNIST-Matrix.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/MNIST.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/MNIST.png deleted file mode 100644 index 77537bb0b9b..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/MNIST.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/mnist-train-xs.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/mnist-train-xs.png deleted file mode 100644 index 5a523e00e1c..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/mnist-train-xs.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/mnist-train-ys.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/mnist-train-ys.png deleted file mode 100644 index 2cd16a29910..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/mnist-train-ys.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-scalarequation.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-scalarequation.png deleted file mode 100644 index 7d5c5938ec9..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-scalarequation.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-scalargraph.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-scalargraph.png deleted file mode 100644 index 6538d98167c..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-scalargraph.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-vectorequation.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-vectorequation.png deleted file mode 100644 index fd0b6a4df3a..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-regression-vectorequation.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-weights.png b/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-weights.png deleted file mode 100644 index 60b584ea9a0..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/beginners/img/softmax-weights.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/tf/mnist_digits.png b/tensorflow/g3doc/tutorials/mnist/tf/mnist_digits.png deleted file mode 100644 index 1c094f8a56b..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/tf/mnist_digits.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/tf/mnist_subgraph.png b/tensorflow/g3doc/tutorials/mnist/tf/mnist_subgraph.png deleted file mode 100644 index 958f5e70789..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/tf/mnist_subgraph.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/mnist/tf/mnist_tensorboard.png b/tensorflow/g3doc/tutorials/mnist/tf/mnist_tensorboard.png deleted file mode 100644 index 35ff1057ad1..00000000000 Binary files a/tensorflow/g3doc/tutorials/mnist/tf/mnist_tensorboard.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/pdes/pde_output_1.jpg b/tensorflow/g3doc/tutorials/pdes/pde_output_1.jpg deleted file mode 100755 index 97954effc00..00000000000 Binary files a/tensorflow/g3doc/tutorials/pdes/pde_output_1.jpg and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/pdes/pde_output_2.jpg b/tensorflow/g3doc/tutorials/pdes/pde_output_2.jpg deleted file mode 100755 index 8cd8cf02b51..00000000000 Binary files a/tensorflow/g3doc/tutorials/pdes/pde_output_2.jpg and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/seq2seq/attention_seq2seq.png b/tensorflow/g3doc/tutorials/seq2seq/attention_seq2seq.png deleted file mode 100644 index 7cd590e7b21..00000000000 Binary files a/tensorflow/g3doc/tutorials/seq2seq/attention_seq2seq.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/seq2seq/basic_seq2seq.png b/tensorflow/g3doc/tutorials/seq2seq/basic_seq2seq.png deleted file mode 100644 index 4e59bcfc4d7..00000000000 Binary files a/tensorflow/g3doc/tutorials/seq2seq/basic_seq2seq.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/word2vec/img/audio-image-text.png b/tensorflow/g3doc/tutorials/word2vec/img/audio-image-text.png deleted file mode 100644 index ec58403937f..00000000000 Binary files a/tensorflow/g3doc/tutorials/word2vec/img/audio-image-text.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/word2vec/img/linear-relationships.png b/tensorflow/g3doc/tutorials/word2vec/img/linear-relationships.png deleted file mode 100644 index 6f69925613d..00000000000 Binary files a/tensorflow/g3doc/tutorials/word2vec/img/linear-relationships.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/word2vec/img/nce-nplm.png b/tensorflow/g3doc/tutorials/word2vec/img/nce-nplm.png deleted file mode 100644 index ad2090435c9..00000000000 Binary files a/tensorflow/g3doc/tutorials/word2vec/img/nce-nplm.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/word2vec/img/softmax-nplm.png b/tensorflow/g3doc/tutorials/word2vec/img/softmax-nplm.png deleted file mode 100644 index 6c372269036..00000000000 Binary files a/tensorflow/g3doc/tutorials/word2vec/img/softmax-nplm.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/word2vec/img/tsne.png b/tensorflow/g3doc/tutorials/word2vec/img/tsne.png deleted file mode 100644 index 848d7affe01..00000000000 Binary files a/tensorflow/g3doc/tutorials/word2vec/img/tsne.png and /dev/null differ diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py index b378026e1b6..e04e86a1005 100644 --- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py @@ -234,7 +234,7 @@ try: tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000) plot_only = 500 low_dim_embs = tsne.fit_transform(final_embeddings[:plot_only,:]) - labels = list(dictionary.keys())[:plot_only] + labels = [reverse_dictionary[i] for i in xrange(plot_only)] plot_with_labels(low_dim_embs, labels) except ImportError: diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py index 7011051aed0..7b103974fc1 100644 --- a/tensorflow/models/embedding/word2vec.py +++ b/tensorflow/models/embedding/word2vec.py @@ -45,6 +45,7 @@ import numpy as np import tensorflow as tf from tensorflow.models.embedding import gen_word2vec as word2vec +from tensorflow.python.util import compat flags = tf.app.flags @@ -178,11 +179,11 @@ class Word2Vec(object): """ questions = [] questions_skipped = 0 - with open(self._options.eval_data) as analogy_f: + with open(self._options.eval_data, "rb") as analogy_f: for line in analogy_f: - if line.startswith(":"): # Skip comments. + if line.startswith(b":"): # Skip comments. continue - words = line.strip().lower().split(" ") + words = line.strip().lower().split(b" ") ids = [self._word2id.get(w.strip()) for w in words] if None in ids or len(ids) != 4: questions_skipped += 1 @@ -380,7 +381,8 @@ class Word2Vec(object): opts = self._options with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f: for i in xrange(opts.vocab_size): - f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n") + f.write("%s %d\n" % (compat.as_text(opts.vocab_words[i]), + opts.vocab_counts[i])) def _train_thread_body(self): initial_epoch, = self._session.run([self._epoch]) diff --git a/tensorflow/models/embedding/word2vec_optimized.py b/tensorflow/models/embedding/word2vec_optimized.py index a7b3338dd51..db43e77ea72 100644 --- a/tensorflow/models/embedding/word2vec_optimized.py +++ b/tensorflow/models/embedding/word2vec_optimized.py @@ -44,6 +44,7 @@ import numpy as np import tensorflow as tf from tensorflow.models.embedding import gen_word2vec as word2vec +from tensorflow.python.util import compat flags = tf.app.flags @@ -158,11 +159,11 @@ class Word2Vec(object): """ questions = [] questions_skipped = 0 - with open(self._options.eval_data) as analogy_f: + with open(self._options.eval_data, "rb") as analogy_f: for line in analogy_f: - if line.startswith(":"): # Skip comments. + if line.startswith(b":"): # Skip comments. continue - words = line.strip().lower().split(" ") + words = line.strip().lower().split(b" ") ids = [self._word2id.get(w.strip()) for w in words] if None in ids or len(ids) != 4: questions_skipped += 1 @@ -240,7 +241,8 @@ class Word2Vec(object): opts = self._options with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f: for i in xrange(opts.vocab_size): - f.write(opts.vocab_words[i] + " " + str(opts.vocab_counts[i]) + "\n") + f.write("%s %d\n" % (compat.as_text(opts.vocab_words[i]), + opts.vocab_counts[i])) def build_eval_graph(self): """Build the evaluation graph.""" diff --git a/tensorflow/models/image/alexnet/alexnet_benchmark.py b/tensorflow/models/image/alexnet/alexnet_benchmark.py index 2dd7b68015e..1f8c4df110a 100644 --- a/tensorflow/models/image/alexnet/alexnet_benchmark.py +++ b/tensorflow/models/image/alexnet/alexnet_benchmark.py @@ -35,10 +35,10 @@ from __future__ import print_function from datetime import datetime import math -from six.moves import xrange # pylint: disable=redefined-builtin import time import tensorflow.python.platform +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf diff --git a/tensorflow/models/image/cifar10/cifar10_input_test.py b/tensorflow/models/image/cifar10/cifar10_input_test.py index a378e431ac0..ed254f2a0de 100644 --- a/tensorflow/models/image/cifar10/cifar10_input_test.py +++ b/tensorflow/models/image/cifar10/cifar10_input_test.py @@ -26,14 +26,15 @@ import tensorflow.python.platform import tensorflow as tf from tensorflow.models.image.cifar10 import cifar10_input +from tensorflow.python.util import compat class CIFAR10InputTest(tf.test.TestCase): def _record(self, label, red, green, blue): image_size = 32 * 32 - record = "%s%s%s%s" % (chr(label), chr(red) * image_size, - chr(green) * image_size, chr(blue) * image_size) + record = bytes(bytearray([label] + [red] * image_size + + [green] * image_size + [blue] * image_size)) expected = [[[red, green, blue]] * 32] * 32 return record, expected @@ -42,10 +43,10 @@ class CIFAR10InputTest(tf.test.TestCase): records = [self._record(labels[0], 0, 128, 255), self._record(labels[1], 255, 0, 1), self._record(labels[2], 254, 255, 0)] - contents = "".join([record for record, _ in records]) + contents = b"".join([record for record, _ in records]) expected = [expected for _, expected in records] filename = os.path.join(self.get_temp_dir(), "cifar") - open(filename, "w").write(contents) + open(filename, "wb").write(contents) with self.test_session() as sess: q = tf.FIFOQueue(99, [tf.string], shapes=()) @@ -56,7 +57,7 @@ class CIFAR10InputTest(tf.test.TestCase): for i in range(3): key, label, uint8image = sess.run([ result.key, result.label, result.uint8image]) - self.assertEqual("%s:%d" % (filename, i), key) + self.assertEqual("%s:%d" % (filename, i), compat.as_text(key)) self.assertEqual(labels[i], label) self.assertAllEqual(expected[i], uint8image) diff --git a/tensorflow/models/rnn/linear_test.py b/tensorflow/models/rnn/linear_test.py index c504fcad68b..22c38434133 100644 --- a/tensorflow/models/rnn/linear_test.py +++ b/tensorflow/models/rnn/linear_test.py @@ -40,7 +40,7 @@ class LinearTest(tf.test.TestCase): # Checks prevent you from accidentally creating a shared function. with self.assertRaises(ValueError) as exc: l1 = linear.linear([x], 2, False) - self.assertEqual(exc.exception.message[:12], "Over-sharing") + self.assertEqual(str(exc.exception)[:12], "Over-sharing") # But you can create a new one in a new scope and share the variables. with tf.variable_scope("l1") as new_scope: diff --git a/tensorflow/models/rnn/rnn.py b/tensorflow/models/rnn/rnn.py index d733de5eab4..b95bf98f723 100644 --- a/tensorflow/models/rnn/rnn.py +++ b/tensorflow/models/rnn/rnn.py @@ -97,17 +97,18 @@ def rnn(cell, inputs, initial_state=None, dtype=None, state.dtype)) max_sequence_length = tf.reduce_max(sequence_length) - output_state = (None, None) for time, input_ in enumerate(inputs): - if time > 0: - tf.get_variable_scope().reuse_variables() - output_state = cell(input_, state) + if time > 0: tf.get_variable_scope().reuse_variables() + # pylint: disable=cell-var-from-loop + def output_state(): + return cell(input_, state) + # pylint: enable=cell-var-from-loop if sequence_length: (output, state) = control_flow_ops.cond( time >= max_sequence_length, - lambda: zero_output_state, lambda: output_state) + lambda: zero_output_state, output_state) else: - (output, state) = output_state + (output, state) = output_state() outputs.append(output) states.append(state) @@ -122,7 +123,7 @@ def state_saving_rnn(cell, inputs, state_saver, state_name, Args: cell: An instance of RNNCell. inputs: A length T list of inputs, each a vector with shape [batch_size]. - state_saver: A StateSaver object. + state_saver: A state saver object with methods `state` and `save_state`. state_name: The name to use with the state_saver. sequence_length: (optional) An int64 vector (tensor) size [batch_size]. See the documentation for rnn() for more details about sequence_length. @@ -137,10 +138,10 @@ def state_saving_rnn(cell, inputs, state_saver, state_name, TypeError: If "cell" is not an instance of RNNCell. ValueError: If inputs is None or an empty list. """ - initial_state = state_saver.State(state_name) + initial_state = state_saver.state(state_name) (outputs, states) = rnn(cell, inputs, initial_state=initial_state, sequence_length=sequence_length, scope=scope) - save_state = state_saver.SaveState(state_name, states[-1]) + save_state = state_saver.save_state(state_name, states[-1]) with tf.control_dependencies([save_state]): outputs[-1] = tf.identity(outputs[-1]) diff --git a/tensorflow/models/rnn/rnn_test.py b/tensorflow/models/rnn/rnn_test.py index a377108f4d7..108a615f9af 100644 --- a/tensorflow/models/rnn/rnn_test.py +++ b/tensorflow/models/rnn/rnn_test.py @@ -50,10 +50,10 @@ class TestStateSaver(object): self._batch_size = batch_size self._state_size = state_size - def State(self, _): + def state(self, _): return tf.zeros(tf.pack([self._batch_size, self._state_size])) - def SaveState(self, _, state): + def save_state(self, _, state): self.saved_state = state return tf.identity(state) @@ -379,7 +379,7 @@ class LSTMTest(tf.test.TestCase): self.assertEqual(len(outputs), len(inputs)) self.assertEqual(len(outputs), len(states)) - tf.initialize_all_variables().run() + tf.initialize_all_variables().run(feed_dict={sequence_length: [2, 3]}) input_value = np.asarray(np.random.randn(batch_size, input_size), dtype=np.float64) values = sess.run(outputs, feed_dict={inputs[0]: input_value, diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py index 56eb08e86c4..77782ee9347 100644 --- a/tensorflow/models/rnn/seq2seq.py +++ b/tensorflow/models/rnn/seq2seq.py @@ -210,8 +210,8 @@ def embedding_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, encoder state, on embedded decoder_inputs. Args: - encoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. - decoder_inputs: a list of 2D Tensors [batch_size x cell.input_size]. + encoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. + decoder_inputs: a list of 1D int32-Tensors of shape [batch_size]. cell: rnn_cell.RNNCell defining the cell function and size. num_encoder_symbols: integer; number of symbols on the encoder side. num_decoder_symbols: integer; number of symbols on the decoder side. diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e6cccbd715e..76126b30435 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -133,7 +133,7 @@ py_library( "framework/random_seed.py", "framework/registry.py", "framework/tensor_shape.py", - "framework/types.py", + "framework/dtypes.py", "framework/tensor_util.py", "ops/common_shapes.py", ], @@ -282,9 +282,9 @@ py_test( ) py_test( - name = "framework_types_test", - srcs = ["framework/types_test.py"], - main = "framework/types_test.py", + name = "framework_dtypes_test", + srcs = ["framework/dtypes_test.py"], + main = "framework/dtypes_test.py", srcs_version = "PY2AND3", deps = [ ":framework_test_lib", @@ -335,6 +335,7 @@ tf_gen_op_wrapper_py( "AllCandidateSampler", "ComputeAccidentalHits", "FixedUnigramCandidateSampler", + "LearnedUnigramCandidateSampler", "LogUniformCandidateSampler", "ThreadUnsafeUnigramCandidateSampler", "UniformCandidateSampler", @@ -374,6 +375,10 @@ tf_gen_op_wrapper_py( "QueueEnqueueMany", "QueueSize", "RandomShuffleQueue", + "Stack", + "StackPop", + "StackPush", + "StackClose", ], require_shape_functions = True, ) @@ -567,6 +572,7 @@ py_library( "ops/gen_string_ops.py", "ops/gen_summary_ops.py", "ops/gradients.py", + "ops/image_grad.py", "ops/image_ops.py", "ops/init_ops.py", "ops/io_ops.py", @@ -944,6 +950,7 @@ py_library( ":client", ":framework", ":pywrap_tensorflow", + ":util", "//tensorflow/core:protos_all_py", ], ) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 7f93d623d39..1932b224cd5 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -28,14 +28,17 @@ import tensorflow as tf """ +import traceback + try: + # pylint: disable=g-import-not-at-top import tensorflow.python.platform from tensorflow.core.framework.graph_pb2 import * -except ImportError as e: - msg = """Error importing tensorflow: you should not try to import - tensorflow from its source directory; please exit the tensorflow source tree, - and relaunch your python interpreter from there. - Original ImportError: %s""" % str(e) +except ImportError: + msg = """%s\n\nError importing tensorflow. Unless you are using bazel, +you should not try to import tensorflow from its source directory; +please exit the tensorflow source tree, and relaunch your python interpreter +from there.""" % traceback.format_exc() raise ImportError(msg) from tensorflow.core.framework.summary_pb2 import * diff --git a/tensorflow/python/client/events_writer.i b/tensorflow/python/client/events_writer.i index cbf42e27910..5d0f5cbc1fe 100644 --- a/tensorflow/python/client/events_writer.i +++ b/tensorflow/python/client/events_writer.i @@ -1,3 +1,19 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +%include "tensorflow/python/lib/core/strings.i" %include "tensorflow/python/platform/base.i" %{ diff --git a/tensorflow/python/client/events_writer_test.py b/tensorflow/python/client/events_writer_test.py index 2a7a750b5b9..25c55e495a3 100644 --- a/tensorflow/python/client/events_writer_test.py +++ b/tensorflow/python/client/events_writer_test.py @@ -26,14 +26,15 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.lib.io import tf_record from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest +from tensorflow.python.util import compat class PywrapeventsWriterTest(test_util.TensorFlowTestCase): def testWriteEvents(self): file_prefix = os.path.join(self.get_temp_dir(), "events") - writer = pywrap_tensorflow.EventsWriter(file_prefix) - filename = writer.FileName() + writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(file_prefix)) + filename = compat.as_text(writer.FileName()) event_written = event_pb2.Event( wall_time=123.45, step=67, summary=summary_pb2.Summary( @@ -66,7 +67,7 @@ class PywrapeventsWriterTest(test_util.TensorFlowTestCase): class _Invalid(object): def __str__(self): return "Invalid" with self.assertRaisesRegexp(TypeError, "Invalid"): - pywrap_tensorflow.EventsWriter("foo").WriteEvent(_Invalid()) + pywrap_tensorflow.EventsWriter(b"foo").WriteEvent(_Invalid()) if __name__ == "__main__": diff --git a/tensorflow/python/client/graph_util.py b/tensorflow/python/client/graph_util.py index 96541864401..31b2dddc23c 100644 --- a/tensorflow/python/client/graph_util.py +++ b/tensorflow/python/client/graph_util.py @@ -24,8 +24,8 @@ import tensorflow.python.platform from tensorflow.core.framework import graph_pb2 from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.platform import logging _VARIABLE_OPS = { @@ -90,18 +90,18 @@ def must_run_on_cpu(node, pin_variables_on_cpu=False): if node_def.op == "Const": # Get the value of the 'dtype' attr dtype = node_def.attr["dtype"].type - if dtype == types.string or dtype == types.int32: + if dtype == dtypes.string or dtype == dtypes.int32: return True if node_def.op == "DynamicStitch": dtype = node_def.attr["T"].type - if dtype == types.int32: + if dtype == dtypes.int32: # DynamicStitch on GPU only works for int32 values. return True if node_def.op in ["Cast"]: dtype = node_def.attr["SrcT"].type - if dtype == types.int32: + if dtype == dtypes.int32: # Cast on GPU does not works for int32 values. return True return False diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py index ade0786b7b7..6b7dba60bc0 100644 --- a/tensorflow/python/client/graph_util_test.py +++ b/tensorflow/python/client/graph_util_test.py @@ -21,8 +21,8 @@ from __future__ import print_function import tensorflow.python.platform from tensorflow.python.client import graph_util +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import constant_op from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import @@ -39,7 +39,7 @@ class DeviceFunctionsTest(googletest.TestCase): const_a = constant_op.constant(5.0) const_b = constant_op.constant(10.0) add_c = const_a + const_b - var_v = state_ops.variable_op([], dtype=types.float32) + var_v = state_ops.variable_op([], dtype=dtypes.float32) assign_c_to_v = state_ops.assign(var_v, add_c) const_string = constant_op.constant("on a cpu") dynamic_stitch_int_result = data_flow_ops.dynamic_stitch( @@ -61,7 +61,7 @@ class DeviceFunctionsTest(googletest.TestCase): const_a = constant_op.constant(5.0) const_b = constant_op.constant(10.0) add_c = const_a + const_b - var_v = state_ops.variable_op([], dtype=types.float32) + var_v = state_ops.variable_op([], dtype=dtypes.float32) assign_c_to_v = state_ops.assign(var_v, add_c) dynamic_stitch_int_result = data_flow_ops.dynamic_stitch( [[0, 1, 2], [2, 3]], [[12, 23, 34], [1, 2]]) @@ -77,16 +77,16 @@ class DeviceFunctionsTest(googletest.TestCase): def testTwoDeviceFunctions(self): with ops.Graph().as_default() as g: - var_0 = state_ops.variable_op([1], dtype=types.float32) + var_0 = state_ops.variable_op([1], dtype=dtypes.float32) with g.device(graph_util.pin_variables_on_cpu): - var_1 = state_ops.variable_op([1], dtype=types.float32) - var_2 = state_ops.variable_op([1], dtype=types.float32) - var_3 = state_ops.variable_op([1], dtype=types.float32) + var_1 = state_ops.variable_op([1], dtype=dtypes.float32) + var_2 = state_ops.variable_op([1], dtype=dtypes.float32) + var_3 = state_ops.variable_op([1], dtype=dtypes.float32) with g.device(graph_util.pin_variables_on_cpu): - var_4 = state_ops.variable_op([1], dtype=types.float32) + var_4 = state_ops.variable_op([1], dtype=dtypes.float32) with g.device("/device:GPU:0"): - var_5 = state_ops.variable_op([1], dtype=types.float32) - var_6 = state_ops.variable_op([1], dtype=types.float32) + var_5 = state_ops.variable_op([1], dtype=dtypes.float32) + var_6 = state_ops.variable_op([1], dtype=dtypes.float32) self.assertEqual(var_0.device, None) self.assertEqual(var_1.device, "/device:CPU:0") diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 882ce903551..8273d4f49dd 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -32,6 +32,7 @@ from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import logging +from tensorflow.python.util import compat class SessionInterface(object): @@ -52,56 +53,6 @@ class SessionInterface(object): raise NotImplementedError('Run') -def _as_bytes(bytes_or_unicode): - """Returns the given argument as a byte array. - - NOTE(mrry): For Python 2 and 3 compatibility, we convert all string - arguments to SWIG methods into byte arrays. Unicode strings are - encoded as UTF-8; however the valid arguments for all of the - human-readable arguments must currently be a subset of ASCII. - - Args: - bytes_or_unicode: A `unicode`, `string`, or `bytes` object. - - Returns: - A `bytes` object. - - Raises: - TypeError: If `bytes_or_unicode` is not a binary or unicode string. - """ - if isinstance(bytes_or_unicode, six.text_type): - return bytes_or_unicode.encode('utf-8') - elif isinstance(bytes_or_unicode, six.binary_type): - return bytes_or_unicode - else: - raise TypeError('bytes_or_unicode must be a binary or unicode string.') - - -def _as_text(bytes_or_unicode): - """Returns the given argument as a unicode string. - - NOTE(mrry): For Python 2 and 3 compatibility, we interpret all - returned strings from SWIG methods as byte arrays. This function - converts those strings that are intended to be human-readable into - UTF-8 unicode strings. - - Args: - bytes_or_unicode: A `unicode`, `string`, or `bytes` object. - - Returns: - A `unicode` (Python 2) or `str` (Python 3) object. - - Raises: - TypeError: If `bytes_or_unicode` is not a binary or unicode string. - """ - if isinstance(bytes_or_unicode, six.text_type): - return bytes_or_unicode - elif isinstance(bytes_or_unicode, six.binary_type): - return bytes_or_unicode.decode('utf-8') - else: - raise TypeError('bytes_or_unicode must be a binary or unicode string.') - - class BaseSession(SessionInterface): """A class for interacting with a TensorFlow computation. @@ -136,16 +87,17 @@ class BaseSession(SessionInterface): self._session = None + opts = tf_session.TF_NewSessionOptions(target=target, config=config) try: - opts = tf_session.TF_NewSessionOptions(target=target, config=config) status = tf_session.TF_NewStatus() - self._session = tf_session.TF_NewSession(opts, status) - if tf_session.TF_GetCode(status) != 0: - raise RuntimeError(_as_text(tf_session.TF_Message(status))) - + try: + self._session = tf_session.TF_NewSession(opts, status) + if tf_session.TF_GetCode(status) != 0: + raise RuntimeError(compat.as_text(tf_session.TF_Message(status))) + finally: + tf_session.TF_DeleteStatus(status) finally: tf_session.TF_DeleteSessionOptions(opts) - tf_session.TF_DeleteStatus(status) def close(self): """Closes this session. @@ -162,7 +114,7 @@ class BaseSession(SessionInterface): status = tf_session.TF_NewStatus() tf_session.TF_CloseSession(self._session, status) if tf_session.TF_GetCode(status) != 0: - raise RuntimeError(_as_text(tf_session.TF_Message(status))) + raise RuntimeError(compat.as_text(tf_session.TF_Message(status))) finally: tf_session.TF_DeleteStatus(status) @@ -173,7 +125,7 @@ class BaseSession(SessionInterface): if self._session is not None: tf_session.TF_DeleteSession(self._session, status) if tf_session.TF_GetCode(status) != 0: - raise RuntimeError(_as_text(tf_session.TF_Message(status))) + raise RuntimeError(compat.as_text(tf_session.TF_Message(status))) self._session = None finally: tf_session.TF_DeleteStatus(status) @@ -369,19 +321,19 @@ class BaseSession(SessionInterface): fetch_t = self.graph.as_graph_element(subfetch, allow_tensor=True, allow_operation=True) if isinstance(fetch_t, ops.Operation): - target_list.append(_as_bytes(fetch_t.name)) + target_list.append(compat.as_bytes(fetch_t.name)) else: - subfetch_names.append(_as_bytes(fetch_t.name)) + subfetch_names.append(compat.as_bytes(fetch_t.name)) except TypeError as e: raise TypeError('Fetch argument %r of %r has invalid type %r, ' 'must be a string or Tensor. (%s)' - % (subfetch, fetch, type(subfetch), e.message)) + % (subfetch, fetch, type(subfetch), str(e))) except ValueError as e: raise ValueError('Fetch argument %r of %r cannot be interpreted as a ' - 'Tensor. (%s)' % (subfetch, fetch, e.message)) + 'Tensor. (%s)' % (subfetch, fetch, str(e))) except KeyError as e: raise ValueError('Fetch argument %r of %r cannot be interpreted as a ' - 'Tensor. (%s)' % (subfetch, fetch, e.message)) + 'Tensor. (%s)' % (subfetch, fetch, str(e))) unique_fetch_targets.update(subfetch_names) fetch_info.append((subfetch_names, fetch_contraction_fn)) @@ -410,7 +362,7 @@ class BaseSession(SessionInterface): 'which has shape %r' % (np_val.shape, subfeed_t.name, tuple(subfeed_t.get_shape().dims))) - feed_dict_string[_as_bytes(subfeed_t.name)] = np_val + feed_dict_string[compat.as_bytes(subfeed_t.name)] = np_val # Run request and get response. results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) @@ -465,7 +417,7 @@ class BaseSession(SessionInterface): tf_session.TF_ExtendGraph( self._session, graph_def.SerializeToString(), status) if tf_session.TF_GetCode(status) != 0: - raise RuntimeError(_as_text(tf_session.TF_Message(status))) + raise RuntimeError(compat.as_text(tf_session.TF_Message(status))) self._opened = True finally: tf_session.TF_DeleteStatus(status) @@ -477,7 +429,7 @@ class BaseSession(SessionInterface): except tf_session.StatusNotOK as e: e_type, e_value, e_traceback = sys.exc_info() - error_message = _as_text(e.error_message) + error_message = compat.as_text(e.error_message) m = BaseSession._NODEDEF_NAME_RE.search(error_message) if m is not None: node_name = m.group(1) @@ -491,7 +443,7 @@ class BaseSession(SessionInterface): raise errors._make_specific_exception(node_def, op, error_message, e.code) # pylint: enable=protected-access - raise e_type, e_value, e_traceback + six.reraise(e_type, e_value, e_traceback) class Session(BaseSession): diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index f83f1a7829a..7dbc5aad875 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -30,10 +30,10 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.framework import config_pb2 from tensorflow.core.lib.core import error_codes_pb2 from tensorflow.python.client import session +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_ops @@ -41,6 +41,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest +from tensorflow.python.util import compat # NOTE(mrry): Dummy shape registration for op used in the tests. @@ -96,7 +97,7 @@ class SessionTest(test_util.TensorFlowTestCase): def testErrorPayload(self): with session.Session(): - a = array_ops.placeholder(types.float32) + a = array_ops.placeholder(dtypes.float32) with self.assertRaisesOpError(lambda e: e.op == a.op): a.eval() @@ -120,7 +121,7 @@ class SessionTest(test_util.TensorFlowTestCase): with sess.graph._original_op(a.op): b = array_ops.identity(a, name='id') with sess.graph._original_op(b.op): - c = array_ops.placeholder(types.float32) + c = array_ops.placeholder(dtypes.float32) # pylint: enable=protected-access def exc_predicate(e): @@ -514,15 +515,15 @@ class SessionTest(test_util.TensorFlowTestCase): def testFeedAndFetch(self): with session.Session(): - for dtype in [types.float32, - types.float64, - types.int32, - types.uint8, - types.int16, - types.int8, - types.int64, - types.bool, - types.complex64]: + for dtype in [dtypes.float32, + dtypes.float64, + dtypes.int32, + dtypes.uint8, + dtypes.int16, + dtypes.int8, + dtypes.int64, + dtypes.bool, + dtypes.complex64]: for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: np_dtype = dtype.as_numpy_dtype @@ -531,9 +532,9 @@ class SessionTest(test_util.TensorFlowTestCase): np_array = np.random.randint(-10, 10, shape) - if dtype == types.bool: + if dtype == dtypes.bool: np_array = np_array > 0 - elif dtype == types.complex64: + elif dtype == dtypes.complex64: np_array = np.sqrt(np_array.astype(np_dtype)) else: np_array = np_array.astype(np_dtype) @@ -547,7 +548,7 @@ class SessionTest(test_util.TensorFlowTestCase): size = 1 for s in shape: size *= s - c_list = np.array([str(i) for i in xrange(size)], + c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)], dtype=np.object).reshape(shape) if size > 0 else [] c = constant_op.constant(c_list) self.assertAllEqual(c.eval(), c_list) @@ -558,16 +559,16 @@ class SessionTest(test_util.TensorFlowTestCase): size = 1 for s in shape: size *= s - c_list = np.array([str(i) for i in xrange(size)], + c_list = np.array([compat.as_bytes(str(i)) for i in xrange(size)], dtype=np.object).reshape(shape) - feed_t = array_ops.placeholder(dtype=types.string, shape=shape) + feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape) c = array_ops.identity(feed_t) self.assertAllEqual(c.eval(feed_dict={feed_t: c_list}), c_list) def testStringFeedWithNullCharacters(self): with session.Session(): - c_list = ['\n\x01\x00', '\n\x00\x01'] - feed_t = array_ops.placeholder(dtype=types.string, shape=[2]) + c_list = [b'\n\x01\x00', b'\n\x00\x01'] + feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2]) c = array_ops.identity(feed_t) out = c.eval(feed_dict={feed_t: c_list}) self.assertEqual(c_list[0], out[0]) @@ -576,7 +577,7 @@ class SessionTest(test_util.TensorFlowTestCase): def testStringFeedWithUnicode(self): with session.Session(): c_list = [u'\n\x01\x00', u'\n\x00\x01'] - feed_t = array_ops.placeholder(dtype=types.string, shape=[2]) + feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2]) c = array_ops.identity(feed_t) out = c.eval(feed_dict={feed_t: c_list}) diff --git a/tensorflow/python/client/tensorflow_server.i b/tensorflow/python/client/tensorflow_server.i index 65b38269615..a0f06137bf5 100644 --- a/tensorflow/python/client/tensorflow_server.i +++ b/tensorflow/python/client/tensorflow_server.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + %include "tensorflow/python/platform/base.i" %import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i" diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 4d6a9c1a58f..168a323a7ea 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -1,9 +1,22 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + %include "tensorflow/python/platform/base.i" %{ -#include "numpy/arrayobject.h" - #include "tensorflow/python/client/tf_session_helper.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/public/status.h" @@ -16,7 +29,7 @@ // Required to use PyArray_* functions. %include "tensorflow/python/platform/numpy.i" %init %{ -import_array(); +tensorflow::ImportNumpy(); %} // Release the Python GIL for the duration of most methods. @@ -160,7 +173,7 @@ import_array(); SWIG_fail; } else { tensorflow::Safe_PyObjectVector out_values_safe; - for (int i = 0; i < $2->size(); ++i) { + for (size_t i = 0; i < $2->size(); ++i) { out_values_safe.emplace_back(tensorflow::make_safe($2->at(i))); } @@ -169,7 +182,7 @@ import_array(); SWIG_fail; } - for (int i = 0; i < $2->size(); ++i) { + for (size_t i = 0; i < $2->size(); ++i) { PyList_SET_ITEM($result, i, $2->at(i)); out_values_safe[i].release(); } diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 26e1c158248..ea24466b7e1 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// We define the PY_ARRAY_UNIQUE_SYMBOL in this .cc file and provide an +// ImportNumpy function to populate it. +#define TF_IMPORT_NUMPY + #include "tensorflow/python/client/tf_session_helper.h" #include <cstring> @@ -80,7 +84,7 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr, Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array, TF_DataType* out_tf_datatype) { int pyarray_type = PyArray_TYPE(array); - PyArray_Descr* descr = array->descr; + PyArray_Descr* descr = PyArray_DESCR(array); switch (pyarray_type) { case NPY_FLOAT32: *out_tf_datatype = TF_FLOAT; @@ -185,8 +189,8 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) { Safe_PyObjectPtr iter = tensorflow::make_safe( PyArray_IterNew(reinterpret_cast<PyObject*>(array))); while (PyArray_ITER_NOTDONE(iter.get())) { - auto item = tensorflow::make_safe( - PyArray_GETITEM(array, PyArray_ITER_DATA(iter.get()))); + auto item = tensorflow::make_safe(PyArray_GETITEM( + array, static_cast<char*>(PyArray_ITER_DATA(iter.get())))); if (!item.get()) { return errors::Internal("Unable to get element from the feed."); } @@ -197,7 +201,7 @@ Status PyBytesArrayMap(PyArrayObject* array, F f) { // Accept unicode in Python 3, by converting to UTF-8 bytes. if (PyUnicode_Check(item.get())) { ptr = PyUnicode_AsUTF8AndSize(item.get(), &len); - if (!buf) { + if (!ptr) { return errors::Internal("Unable to get element from the feed."); } } else { @@ -260,7 +264,8 @@ static Status TF_StringTensor_GetPtrAndLen(const TF_Tensor* src, reinterpret_cast<const tensorflow::uint64*>(input)[i]; const char* p = tensorflow::core::GetVarint64Ptr(data_start + offset, limit, len); - if (offset >= (limit - data_start) || !p || (*len > (limit - p))) { + if (static_cast<int64>(offset) >= (limit - data_start) || !p || + static_cast<int64>(*len) > (limit - p)) { return errors::InvalidArgument("Malformed TF_STRING tensor; element ", i, " out of range"); } @@ -279,8 +284,8 @@ static Status CopyStringToPyArrayElement(PyArrayObject* pyarray, void* i_ptr, TF_RETURN_IF_ERROR( TF_StringTensor_GetPtrAndLen(tensor, num_elements, i, &ptr, &len)); auto py_string = tensorflow::make_safe(PyBytes_FromStringAndSize(ptr, len)); - int success = - PyArray_SETITEM(pyarray, PyArray_ITER_DATA(i_ptr), py_string.get()); + int success = PyArray_SETITEM( + pyarray, static_cast<char*>(PyArray_ITER_DATA(i_ptr)), py_string.get()); if (success != 0) { return errors::Internal("Error setting element ", i); } @@ -323,7 +328,8 @@ Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) { } PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(safe_out_array.get()); - if (PyArray_NBYTES(py_array) != TF_TensorByteSize(tensor)) { + if (PyArray_NBYTES(py_array) != + static_cast<int64>(TF_TensorByteSize(tensor))) { if (TF_TensorType(tensor) == TF_STRING) { // Copy element by element. auto iter = tensorflow::make_safe(PyArray_IterNew(safe_out_array.get())); @@ -341,7 +347,8 @@ Status TF_Tensor_to_PyObject(TF_Tensor* tensor, PyObject** out_array) { TF_TensorByteSize(tensor), " bytes"); } } else { - memcpy(py_array->data, TF_TensorData(tensor), PyArray_NBYTES(py_array)); + memcpy(PyArray_DATA(py_array), TF_TensorData(tensor), + PyArray_NBYTES(py_array)); } // PyArray_Return turns rank 0 arrays into numpy scalars @@ -395,8 +402,6 @@ tensorflow::Status TF_Status_to_Status(TF_Status* tf_status) { } } -static bool numpy_imported = false; - } // namespace Safe_PyObjectPtr make_safe(PyObject* o) { @@ -410,12 +415,6 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, const NameVector& output_names, const NameVector& target_nodes, Status* out_status, PyObjectVector* out_values) { - // 0. Ensure that numpy has been imported. - if (!numpy_imported) { - import_array(); - numpy_imported = true; - } - // 1. Convert the feed inputs to the appropriate form for TF_Run. NameVector input_names; Safe_PyObjectVector @@ -428,7 +427,7 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, make_safe(reinterpret_cast<PyObject*>(name_and_array.second))); } - for (int i = 0; i < inputs.size(); ++i) { + for (size_t i = 0; i < inputs.size(); ++i) { input_names.push_back(inputs[i].first); PyArrayObject* array = inputs[i].second; @@ -460,7 +459,7 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, // requirements for tensorflow::Tensor. We hard code this here to // avoid taking a dependency on Eigen in the client code. void* data = tensorflow::cpu_allocator()->AllocateRaw(32, size); - std::memcpy(data, array->data, size); + std::memcpy(data, PyArray_DATA(array), size); inputs_safe.emplace_back(make_safe( TF_NewTensor(dtype, dims.data(), dims.size(), data, size, [](void* data, size_t len, void* arg) { @@ -525,7 +524,7 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, // 6. Convert the fetched tensors into numpy ndarrays. Store them in a safe // container so that we do not leak Safe_PyObjectVector py_outputs_safe; - for (int i = 0; i < output_names.size(); ++i) { + for (size_t i = 0; i < output_names.size(); ++i) { PyObject* py_array; *out_status = TF_Tensor_to_PyObject(outputs[i], &py_array); if (!out_status->ok()) { @@ -542,4 +541,6 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, *out_status = Status::OK(); } +void ImportNumpy() { import_array1(); } + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 80b38415cc3..dfac9f789d5 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -16,6 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ #define TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ +#ifdef PyArray_Type +#error "Numpy cannot be included before tf_session_helper.h." +#endif + +// Disallow Numpy 1.7 deprecated symbols. +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +// We import_array in the tensorflow init function only. +#define PY_ARRAY_UNIQUE_SYMBOL _tensorflow_numpy_api +#ifndef TF_IMPORT_NUMPY +#define NO_IMPORT_ARRAY +#endif + #include <Python.h> #include "numpy/arrayobject.h" @@ -66,6 +79,11 @@ void TF_Run_wrapper(TF_Session* session, const FeedVector& inputs, const NameVector& target_nodes, Status* out_status, PyObjectVector* out_values); +// Import numpy. This wrapper function exists so that the +// PY_ARRAY_UNIQUE_SYMBOL can be safely defined in a .cc file to +// avoid weird linking issues. +void ImportNumpy(); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py index 5aefc0d120f..145b8049ad2 100644 --- a/tensorflow/python/framework/docs.py +++ b/tensorflow/python/framework/docs.py @@ -238,10 +238,12 @@ class Library(Document): name, member tuples. """ for name, member in inspect.getmembers(cls): - # Only show methods and properties presently. - if not (inspect.ismethod(member) or isinstance(member, property)): + # Only show methods and properties presently. In Python 3, + # methods register as isfunction. + is_method = inspect.ismethod(member) or inspect.isfunction(member) + if not (is_method or isinstance(member, property)): continue - if ((inspect.ismethod(member) and member.__name__ == "__init__") + if ((is_method and member.__name__ == "__init__") or self._should_include_member(name, member)): yield name, ("%s.%s" % (cls_name, name), member) @@ -371,27 +373,19 @@ class Library(Document): self._print_formatted_docstring(inspect.getdoc(func), f) print("", file=f) - def _write_member_markdown_to_file(self, f, name, member): + def _write_member_markdown_to_file(self, f, prefix, name, member): """Print `member` to `f`.""" - if inspect.isfunction(member): + if (inspect.isfunction(member) or inspect.ismethod(member) or + isinstance(member, property)): print("- - -", file=f) print("", file=f) - self._print_function(f, "###", name, member) + self._print_function(f, prefix, name, member) print("", file=f) - elif inspect.ismethod(member): - print("- - -", file=f) - print("", file=f) - self._print_function(f, "####", name, member) - print("", file=f) - elif isinstance(member, property): - print("- - -", file=f) - print("", file=f) - self._print_function(f, "####", name, member) elif inspect.isclass(member): print("- - -", file=f) print("", file=f) - print("### `class %s` {#%s}" % (name, - _get_anchor(self._module_to_name, name)), + print("%s `class %s` {#%s}" % (prefix, name, + _get_anchor(self._module_to_name, name)), file=f) print("", file=f) self._write_class_markdown_to_file(f, name, member) @@ -399,14 +393,15 @@ class Library(Document): else: raise RuntimeError("Member %s has unknown type %s" % (name, type(member))) - def _write_docstring_markdown_to_file(self, f, docstring, members, imports): + def _write_docstring_markdown_to_file(self, f, prefix, docstring, members, + imports): for l in self._remove_docstring_indent(docstring): if l.startswith(_member_mark): name = l[len(_member_mark):].strip(" \t") if name in members: self._documented.add(name) self._mentioned.add(name) - self._write_member_markdown_to_file(f, *members[name]) + self._write_member_markdown_to_file(f, prefix, *members[name]) del members[name] elif name in imports: self._write_module_markdown_to_file(f, imports[name]) @@ -429,7 +424,11 @@ class Library(Document): # Used later to check if any methods were called out in the class # docstring. num_methods = len(methods) - self._write_docstring_markdown_to_file(f, inspect.getdoc(cls), methods, {}) + try: + self._write_docstring_markdown_to_file(f, "####", inspect.getdoc(cls), + methods, {}) + except ValueError as e: + raise ValueError(str(e) + " in class `%s`" % cls.__name__) # If some methods were not described, describe them now if they are # defined by the class itself (not inherited). If NO methods were @@ -445,11 +444,11 @@ class Library(Document): else: other_methods = methods for name in sorted(other_methods): - self._write_member_markdown_to_file(f, *other_methods[name]) + self._write_member_markdown_to_file(f, "####", *other_methods[name]) def _write_module_markdown_to_file(self, f, module): imports = dict(self.get_imported_modules(module)) - self._write_docstring_markdown_to_file(f, inspect.getdoc(module), + self._write_docstring_markdown_to_file(f, "###", inspect.getdoc(module), self._members, imports) def write_markdown_to_file(self, f): @@ -496,7 +495,7 @@ class Library(Document): print(" %s" % name) self._documented.add(name) self._mentioned.add(name) - self._write_member_markdown_to_file(f, *self._members[name]) + self._write_member_markdown_to_file(f, "###", *self._members[name]) def assert_no_leftovers(self): """Generate an error if there are leftover members.""" diff --git a/tensorflow/python/framework/types.py b/tensorflow/python/framework/dtypes.py similarity index 99% rename from tensorflow/python/framework/types.py rename to tensorflow/python/framework/dtypes.py index 123dba558f2..48bf6cac00f 100644 --- a/tensorflow/python/framework/types.py +++ b/tensorflow/python/framework/dtypes.py @@ -70,8 +70,8 @@ class DType(object): """Creates a new `DataType`. NOTE(mrry): In normal circumstances, you should not need to - construct a DataType object directly. Instead, use the - types.as_dtype() function. + construct a `DataType` object directly. Instead, use the + `tf.as_dtype()` function. Args: type_enum: A `types_pb2.DataType` enum value. @@ -223,6 +223,9 @@ class DType(object): def __repr__(self): return "tf." + self.name + def __hash__(self): + return self._type_enum + # Define standard wrappers for the types_pb2.DataType enum. float32 = DType(types_pb2.DT_FLOAT) diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py new file mode 100644 index 00000000000..6a052084107 --- /dev/null +++ b/tensorflow/python/framework/dtypes_test.py @@ -0,0 +1,205 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for tensorflow.python.framework.importer.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class TypesTest(test_util.TensorFlowTestCase): + + def testAllTypesConstructible(self): + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + self.assertEqual( + datatype_enum, tf.DType(datatype_enum).as_datatype_enum) + + def testAllTypesConvertibleToDType(self): + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + self.assertEqual( + datatype_enum, tf.as_dtype(datatype_enum).as_datatype_enum) + + def testAllTypesConvertibleToNumpyDtype(self): + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = tf.as_dtype(datatype_enum) + numpy_dtype = dtype.as_numpy_dtype + _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype) + if dtype.base_dtype != tf.bfloat16: + # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. + self.assertEqual(tf.as_dtype(datatype_enum).base_dtype, + tf.as_dtype(numpy_dtype)) + + def testInvalid(self): + with self.assertRaises(TypeError): + tf.DType(types_pb2.DT_INVALID) + with self.assertRaises(TypeError): + tf.as_dtype(types_pb2.DT_INVALID) + + def testNumpyConversion(self): + self.assertIs(tf.float32, tf.as_dtype(np.float32)) + self.assertIs(tf.float64, tf.as_dtype(np.float64)) + self.assertIs(tf.int32, tf.as_dtype(np.int32)) + self.assertIs(tf.int64, tf.as_dtype(np.int64)) + self.assertIs(tf.uint8, tf.as_dtype(np.uint8)) + self.assertIs(tf.int16, tf.as_dtype(np.int16)) + self.assertIs(tf.int8, tf.as_dtype(np.int8)) + self.assertIs(tf.complex64, tf.as_dtype(np.complex64)) + self.assertIs(tf.string, tf.as_dtype(np.object)) + self.assertIs(tf.string, tf.as_dtype(np.array(["foo", "bar"]).dtype)) + self.assertIs(tf.bool, tf.as_dtype(np.bool)) + with self.assertRaises(TypeError): + tf.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)])) + + def testStringConversion(self): + self.assertIs(tf.float32, tf.as_dtype("float32")) + self.assertIs(tf.float64, tf.as_dtype("float64")) + self.assertIs(tf.int32, tf.as_dtype("int32")) + self.assertIs(tf.uint8, tf.as_dtype("uint8")) + self.assertIs(tf.int16, tf.as_dtype("int16")) + self.assertIs(tf.int8, tf.as_dtype("int8")) + self.assertIs(tf.string, tf.as_dtype("string")) + self.assertIs(tf.complex64, tf.as_dtype("complex64")) + self.assertIs(tf.int64, tf.as_dtype("int64")) + self.assertIs(tf.bool, tf.as_dtype("bool")) + self.assertIs(tf.qint8, tf.as_dtype("qint8")) + self.assertIs(tf.quint8, tf.as_dtype("quint8")) + self.assertIs(tf.qint32, tf.as_dtype("qint32")) + self.assertIs(tf.bfloat16, tf.as_dtype("bfloat16")) + self.assertIs(tf.float32_ref, tf.as_dtype("float32_ref")) + self.assertIs(tf.float64_ref, tf.as_dtype("float64_ref")) + self.assertIs(tf.int32_ref, tf.as_dtype("int32_ref")) + self.assertIs(tf.uint8_ref, tf.as_dtype("uint8_ref")) + self.assertIs(tf.int16_ref, tf.as_dtype("int16_ref")) + self.assertIs(tf.int8_ref, tf.as_dtype("int8_ref")) + self.assertIs(tf.string_ref, tf.as_dtype("string_ref")) + self.assertIs(tf.complex64_ref, tf.as_dtype("complex64_ref")) + self.assertIs(tf.int64_ref, tf.as_dtype("int64_ref")) + self.assertIs(tf.bool_ref, tf.as_dtype("bool_ref")) + self.assertIs(tf.qint8_ref, tf.as_dtype("qint8_ref")) + self.assertIs(tf.quint8_ref, tf.as_dtype("quint8_ref")) + self.assertIs(tf.qint32_ref, tf.as_dtype("qint32_ref")) + self.assertIs(tf.bfloat16_ref, tf.as_dtype("bfloat16_ref")) + with self.assertRaises(TypeError): + tf.as_dtype("not_a_type") + + def testDTypesHaveUniqueNames(self): + dtypes = [] + names = set() + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = tf.as_dtype(datatype_enum) + dtypes.append(dtype) + names.add(dtype.name) + self.assertEqual(len(dtypes), len(names)) + + def testIsInteger(self): + self.assertEqual(tf.as_dtype("int8").is_integer, True) + self.assertEqual(tf.as_dtype("int16").is_integer, True) + self.assertEqual(tf.as_dtype("int32").is_integer, True) + self.assertEqual(tf.as_dtype("int64").is_integer, True) + self.assertEqual(tf.as_dtype("uint8").is_integer, True) + self.assertEqual(tf.as_dtype("complex64").is_integer, False) + self.assertEqual(tf.as_dtype("float").is_integer, False) + self.assertEqual(tf.as_dtype("double").is_integer, False) + self.assertEqual(tf.as_dtype("string").is_integer, False) + self.assertEqual(tf.as_dtype("bool").is_integer, False) + + def testIsFloating(self): + self.assertEqual(tf.as_dtype("int8").is_floating, False) + self.assertEqual(tf.as_dtype("int16").is_floating, False) + self.assertEqual(tf.as_dtype("int32").is_floating, False) + self.assertEqual(tf.as_dtype("int64").is_floating, False) + self.assertEqual(tf.as_dtype("uint8").is_floating, False) + self.assertEqual(tf.as_dtype("complex64").is_floating, False) + self.assertEqual(tf.as_dtype("float32").is_floating, True) + self.assertEqual(tf.as_dtype("float64").is_floating, True) + self.assertEqual(tf.as_dtype("string").is_floating, False) + self.assertEqual(tf.as_dtype("bool").is_floating, False) + + def testMinMax(self): + # make sure min/max evaluates for all data types that have min/max + for datatype_enum in types_pb2.DataType.values(): + if datatype_enum == types_pb2.DT_INVALID: + continue + dtype = tf.as_dtype(datatype_enum) + numpy_dtype = dtype.as_numpy_dtype + + # ignore types for which there are no minimum/maximum (or we cannot + # compute it, such as for the q* types) + if (dtype.is_quantized or + dtype.base_dtype == tf.bool or + dtype.base_dtype == tf.string or + dtype.base_dtype == tf.complex64): + continue + + print("%s: %s - %s" % (dtype, dtype.min, dtype.max)) + + # check some values that are known + if numpy_dtype == np.bool_: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 1) + if numpy_dtype == np.int8: + self.assertEquals(dtype.min, -128) + self.assertEquals(dtype.max, 127) + if numpy_dtype == np.int16: + self.assertEquals(dtype.min, -32768) + self.assertEquals(dtype.max, 32767) + if numpy_dtype == np.int32: + self.assertEquals(dtype.min, -2147483648) + self.assertEquals(dtype.max, 2147483647) + if numpy_dtype == np.int64: + self.assertEquals(dtype.min, -9223372036854775808) + self.assertEquals(dtype.max, 9223372036854775807) + if numpy_dtype == np.uint8: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 255) + if numpy_dtype == np.uint16: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 4294967295) + if numpy_dtype == np.uint32: + self.assertEquals(dtype.min, 0) + self.assertEquals(dtype.max, 18446744073709551615) + if numpy_dtype in (np.float16, np.float32, np.float64): + self.assertEquals(dtype.min, np.finfo(numpy_dtype).min) + self.assertEquals(dtype.max, np.finfo(numpy_dtype).max) + + def testRepr(self): + for enum, name in dtypes._TYPE_TO_STRING.items(): + dtype = tf.DType(enum) + self.assertEquals(repr(dtype), 'tf.' + name) + dtype2 = eval(repr(dtype)) + self.assertEquals(type(dtype2), tf.DType) + self.assertEquals(dtype, dtype2) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py index 9803401e2c7..e85a72e0e1f 100644 --- a/tensorflow/python/framework/framework_lib.py +++ b/tensorflow/python/framework/framework_lib.py @@ -86,4 +86,4 @@ from tensorflow.python.framework.ops import RegisterShape from tensorflow.python.framework.tensor_shape import Dimension from tensorflow.python.framework.tensor_shape import TensorShape -from tensorflow.python.framework.types import * +from tensorflow.python.framework.dtypes import * diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index ee3d130f626..e52a0e20373 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -22,13 +22,12 @@ import contextlib import tensorflow.python.platform -import six - from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types as types_lib +from tensorflow.python.util import compat # TODO(josh11b): SWIG the code from node_def_util instead of duplicating @@ -61,7 +60,7 @@ def _ArgToTypesNoRef(node_def, arg_def): def _SingleArgToTypes(node_def, arg_def): types = _ArgToTypesNoRef(node_def, arg_def) if arg_def.is_ref: - return [types_lib.as_dtype(dt).as_ref.as_datatype_enum for dt in types] + return [dtypes.as_dtype(dt).as_ref.as_datatype_enum for dt in types] return types @@ -122,6 +121,7 @@ def _ParseTensorName(tensor_name): def _CanonicalInputName(input_name): + input_name = compat.as_str(input_name) if _IsControlInput(input_name): return input_name input_op_name, output_index = _ParseTensorName(input_name) @@ -170,11 +170,11 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, Returns: A list of `Operation` and/or `Tensor` objects from the imported graph, - corresponding to the names in `return_elements'. + corresponding to the names in `return_elements`. Raises: TypeError: If `graph_def` is not a `GraphDef` proto, - `input_map' is not a dictionary mapping strings to `Tensor` objects, + `input_map` is not a dictionary mapping strings to `Tensor` objects, or `return_elements` is not a list of strings. ValueError: If `input_map`, or `return_elements` contains names that do not appear in `graph_def`, or `graph_def` is not well-formed (e.g. @@ -194,14 +194,15 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, input_map = {} else: if not (isinstance(input_map, dict) - and all(isinstance(k, six.string_types) for k in input_map.keys())): + and all(isinstance(k, compat.bytes_or_text_types) + for k in input_map.keys())): raise TypeError('input_map must be a dictionary mapping strings to ' 'Tensor objects.') - if (return_elements is not None - and not (isinstance(return_elements, (list, tuple)) - and all(isinstance(x, six.string_types) - for x in return_elements))): - raise TypeError('return_elements must be a list of strings.') + if return_elements is not None: + return_elements = tuple(return_elements) + if not all(isinstance(x, compat.bytes_or_text_types) + for x in return_elements): + raise TypeError('return_elements must be a list of strings.') # Use a canonical representation for all tensor names. input_map = {_CanonicalInputName(k): v for k, v in input_map.items()} @@ -219,7 +220,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()} # NOTE(mrry): We do this in two passes, because there may be a cycle in - # `graph_def'. + # `graph_def`. # 1. Add operations without their inputs. for node in graph_def.node: @@ -267,7 +268,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, used_input_keys.add(input_name) else: - # (c) Input should be taken from an op in `graph_def'. + # (c) Input should be taken from an op in `graph_def`. operation_name, output_index = _ParseTensorName(input_name) try: source_op = name_to_op[operation_name] @@ -284,9 +285,8 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, op._add_input(source_tensor, dtype=input_type) # pylint: enable=protected-access except TypeError as te: - raise ValueError( - _InvalidNodeMessage(node, 'Input tensor %r %s' - % (input_name, te.message))) + raise ValueError(_InvalidNodeMessage( + node, 'Input tensor %r %s' % (input_name, te))) # pylint: disable=protected_access if op._input_dtypes != input_types: @@ -294,7 +294,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, _InvalidNodeMessage( node, 'Input types mismatch (expected %r but got %r)' - % (", ".join(types_lib.as_dtype(x).name for x in input_types), + % (", ".join(dtypes.as_dtype(x).name for x in input_types), ", ".join(x.name for x in op._input_dtypes)))) # pylint: enable=protected_access @@ -316,6 +316,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, else: ret = [] for name in return_elements: + name = compat.as_str(name) if ':' in name: try: operation_name, output_index = _ParseTensorName(name) diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 95059b17d7e..188ec2edcc0 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -335,7 +335,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } node { name: 'B' op: 'None' input: 'A:0' } """)) - self.assertTrue('More inputs specified (u\'A:0\') than the op expects' in + self.assertTrue('More inputs specified (\'A:0\') than the op expects' in str(e.exception)) def testInvalidSignatureNotEnoughInputsInGraphDef(self): @@ -356,8 +356,7 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'B' op: 'If' input: 'A:0' } """)) - self.assertTrue('Input tensor %r not found' % (u'A:0',) in - str(e.exception)) + self.assertTrue("Input tensor 'A:0' not found" in str(e.exception)) def testMissingInputOpInGraphDefButAppearsInInputMap(self): with tf.Graph().as_default(): @@ -378,8 +377,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Of' } node { name: 'B' op: 'If' input: 'A:1' } """)) - self.assertTrue('Input tensor %r not found' % (u'A:1',) in - str(e.exception)) + self.assertTrue("Input tensor 'A:1' not found" in str(e.exception)) def testMissingControlInputInGraphDef(self): with tf.Graph().as_default(): @@ -388,8 +386,7 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'B' op: 'None' input: '^A' } """)) - self.assertTrue('Control input %r not found' % (u'^A',) in - str(e.exception)) + self.assertTrue("Control input '^A' not found" in str(e.exception)) def testInvalidTensorNameOutputIndexInGraphDef(self): with tf.Graph().as_default(): @@ -398,8 +395,8 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'B' op: 'None' input: 'A:B' } """)) - self.assertEqual( - 'Cannot convert %r to a tensor name.' % (u'A:B',), str(e.exception)) + self.assertEqual("Cannot convert 'A:B' to a tensor name.", + str(e.exception)) def testInvalidTensorNameInGraphDef(self): with tf.Graph().as_default(): @@ -408,8 +405,8 @@ class ImportGraphDefTest(tf.test.TestCase): self._MakeGraphDef(""" node { name: 'B' op: 'None' input: 'A:B:0' } """)) - self.assertEqual( - 'Cannot convert %r to a tensor name.' % (u'A:B:0',), str(e.exception)) + self.assertEqual("Cannot convert 'A:B:0' to a tensor name.", + str(e.exception)) def testMissingReturnOperation(self): with tf.Graph().as_default(): @@ -419,7 +416,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'None' } """), return_elements=['B']) - self.assertTrue('return_element %r not found in graph_def.' % ('B') in + self.assertTrue("return_element 'B' not found in graph_def." in str(e.exception)) def testMissingReturnTensor(self): @@ -430,7 +427,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } """), return_elements=['A:1']) - self.assertTrue('return_element %r not found in graph_def.' % ('A:1') in + self.assertTrue("return_element 'A:1' not found in graph_def." in str(e.exception)) with self.assertRaises(ValueError) as e: @@ -439,7 +436,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } """), return_elements=['B:0']) - self.assertTrue('return_element %r not found in graph_def.' % ('B:0') in + self.assertTrue("return_element 'B:0' not found in graph_def." in str(e.exception)) with self.assertRaises(ValueError) as e: @@ -448,7 +445,7 @@ class ImportGraphDefTest(tf.test.TestCase): node { name: 'A' op: 'Oi' } """), return_elements=['A:B:0']) - self.assertTrue('return_element %r not found in graph_def.' % ('A:B:0') in + self.assertTrue("return_element 'A:B:0' not found in graph_def." in str(e.exception)) def testMissingInputMap(self): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 00bd776fad4..d66e93300d9 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -34,9 +34,10 @@ import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types +from tensorflow.python.util import compat def _convert_stack(stack): @@ -195,7 +196,7 @@ class Tensor(object): raise TypeError("op needs to be an Operation: %s" % op) self._op = op self._value_index = value_index - self._dtype = types.as_dtype(dtype) + self._dtype = dtypes.as_dtype(dtype) self._shape = tensor_shape.unknown_shape() # List of operations that use this Tensor as input. We maintain this list # to easily navigate a computation graph. @@ -393,10 +394,11 @@ class Tensor(object): ValueError: If operator has already been overwritten, or if operator is not allowed to be overwritten. """ - if getattr(Tensor, operator, None) is not None: - # check to see if this is a default method-wrapper which will be true - # for the comparison operators. - if not isinstance(getattr(Tensor, operator, None), type(all.__call__)): + existing = getattr(Tensor, operator, None) + if existing is not None: + # Check to see if this is a default method-wrapper or slot wrapper which + # will be true for the comparison operators. + if not isinstance(existing, type(object.__lt__)): raise ValueError("operator %s cannot be overwritten again." % operator) if operator not in Tensor.OVERLOADABLE_OPERATORS: raise ValueError("Overriding %s is disallowed" % operator) @@ -496,7 +498,7 @@ def convert_to_tensor(value, dtype=None, name=None): """ error_prefix = "" if name is None else "%s: " % name if dtype is not None: - dtype = types.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()): for base_type, conversion_func in funcs_at_priority: if isinstance(value, base_type): @@ -538,10 +540,10 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None): ValueError: If `dtype` does not match the element type of `value`. """ if isinstance(value, IndexedSlices): - if dtype and not types.AsDType(dtype).is_compatible_with(value.dtype): + if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): raise ValueError( "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" - % (types.AsDType(dtype).name, value.dtype.name, str(value))) + % (dtypes.as_dtype(dtype).name, value.dtype.name, str(value))) return value else: return convert_to_tensor(value, dtype, name) @@ -880,8 +882,8 @@ def _NodeDef(op_type, name, device=None, attrs=None): A graph_pb2.NodeDef protocol buffer. """ node_def = graph_pb2.NodeDef() - node_def.op = str(op_type) - node_def.name = str(name) + node_def.op = compat.as_bytes(op_type) + node_def.name = compat.as_bytes(name) if attrs is not None: for k, v in six.iteritems(attrs): node_def.attr[k].CopyFrom(v) @@ -966,11 +968,11 @@ class Operation(object): Raises: TypeError: if control inputs are not Operations or Tensors, - or if node_def is not a `NodeDef`, - or if g is not a `Graph`, - or if inputs are not tensors, - or if inputs and input_types are incompatible. - ValueError: if the node_def name is not valid. + or if `node_def` is not a `NodeDef`, + or if `g` is not a `Graph`, + or if `inputs` are not tensors, + or if `inputs` and `input_types` are incompatible. + ValueError: if the `node_def` name is not valid. """ if not isinstance(node_def, graph_pb2.NodeDef): raise TypeError("node_def needs to be a NodeDef: %s" % node_def) @@ -1079,7 +1081,7 @@ class Operation(object): Args: tensor: the Tensor to add as an input. - dtype: types.DType: type of the input; defaults to + dtype: tf.DType: type of the input; defaults to the tensor's dtype. Raises: @@ -1093,7 +1095,7 @@ class Operation(object): if dtype is None: dtype = tensor.dtype else: - dtype = types.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) if not dtype.is_compatible_with(tensor.dtype): raise TypeError( "Cannot convert a tensor of type %s to an input of type %s" @@ -1111,7 +1113,7 @@ class Operation(object): Args: index: the index of the input to update. tensor: the Tensor to be used as the input at the given index. - dtype: types.DType: type of the input; defaults to + dtype: tf.DType: type of the input; defaults to the tensor's dtype. Raises: @@ -1125,7 +1127,7 @@ class Operation(object): if dtype is None: dtype = tensor.dtype else: - dtype = types.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) if not dtype.is_compatible_with(tensor.dtype): raise TypeError( "Cannot convert a tensor of type %s to an input of type %s" @@ -1414,7 +1416,7 @@ class RegisterShape(object): """ def __init__(self, op_type): - """Saves the "op_type" as the Operation type.""" + """Saves the `op_type` as the `Operation` type.""" if not isinstance(op_type, six.string_types): raise TypeError("op_type must be a string") self._op_type = op_type @@ -1664,7 +1666,7 @@ class Graph(object): protocol buffer. Raises: - ValueError: If the graph_def would be too large. + ValueError: If the `graph_def` would be too large. """ graph = graph_pb2.GraphDef() bytesize = 0 @@ -1763,7 +1765,7 @@ class Graph(object): try: kernel_label = self._op_to_kernel_label_map[op_type] node_def.attr["_kernel"].CopyFrom( - attr_value_pb2.AttrValue(s=kernel_label)) + attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) except KeyError: pass @@ -1772,7 +1774,7 @@ class Graph(object): try: mapped_op_type = self._gradient_override_map[op_type] node_def.attr["_gradient_op_type"].CopyFrom( - attr_value_pb2.AttrValue(s=mapped_op_type)) + attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) except KeyError: pass @@ -1843,8 +1845,8 @@ class Graph(object): obj = conv_fn() # If obj appears to be a name... - if isinstance(obj, six.string_types): - name = obj + if isinstance(obj, compat.bytes_or_text_types): + name = compat.as_str(obj) if ":" in name and allow_tensor: # Looks like a Tensor name and can be a Tensor. @@ -2926,8 +2928,7 @@ def _get_graph_from_inputs(op_input_list, graph=None): Returns: The appropriate graph to use for the given inputs. """ - if not isinstance(op_input_list, (list, tuple)): - raise TypeError("The op_input_list must be a list or tuple") + op_input_list = tuple(op_input_list) # Handle generators correctly # 1. If the graph is specified explicitly, we validate that all of the inputs # are compatible with that graph. diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index d9dbd99b29e..b6dab941020 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -21,11 +21,11 @@ from __future__ import print_function import tensorflow.python.platform from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_kernel_label_op from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.platform import googletest @@ -34,11 +34,11 @@ class TensorTest(test_util.TensorFlowTestCase): def testShape(self): op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), - [], [types.float32]) + [], [dtypes.float32]) t = op.outputs[0] - self.assertEquals(tensor_shape.unknown_shape(), t.get_shape()) + self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) t.set_shape([1, 2, 3]) - self.assertEquals([1, 2, 3], t.get_shape()) + self.assertEqual([1, 2, 3], t.get_shape()) class NodeDefConstructorTest(test_util.TensorFlowTestCase): @@ -84,38 +84,38 @@ class OperationTest(test_util.TensorFlowTestCase): def testNoInputs(self): op = ops.Operation(ops._NodeDef("noop", "myop"), ops.Graph(), [], - [types.float32, types.string]) - self.assertEquals(2, len(op.values())) - self.assertEquals(0, len(op.inputs)) - self.assertEquals("myop", op.name) + [dtypes.float32, dtypes.string]) + self.assertEqual(2, len(op.values())) + self.assertEqual(0, len(op.inputs)) + self.assertEqual("myop", op.name) float_t, label_str_t = op.values() - self.assertEquals(types.float32, float_t.dtype) - self.assertEquals(op, float_t.op) - self.assertEquals(0, float_t._value_index) - self.assertEquals(0, len(float_t._consumers)) - self.assertEquals("myop", float_t._as_node_def_input()) + self.assertEqual(dtypes.float32, float_t.dtype) + self.assertEqual(op, float_t.op) + self.assertEqual(0, float_t._value_index) + self.assertEqual(0, len(float_t._consumers)) + self.assertEqual("myop", float_t._as_node_def_input()) - self.assertEquals(types.string, label_str_t.dtype) - self.assertEquals(op, label_str_t.op) - self.assertEquals(1, label_str_t._value_index) - self.assertEquals(0, len(label_str_t._consumers)) - self.assertEquals("myop:1", label_str_t._as_node_def_input()) + self.assertEqual(dtypes.string, label_str_t.dtype) + self.assertEqual(op, label_str_t.op) + self.assertEqual(1, label_str_t._value_index) + self.assertEqual(0, len(label_str_t._consumers)) + self.assertEqual("myop:1", label_str_t._as_node_def_input()) self.assertProtoEquals("op:'noop' name:'myop'", op.node_def) def testNoOutputs(self): g = ops.Graph() op1 = ops.Operation( - ops._NodeDef("noop", "myop1"), g, [], [types.float32]) + ops._NodeDef("noop", "myop1"), g, [], [dtypes.float32]) float_t, = op1.values() op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, [float_t], []) - self.assertEquals(0, len(op2.values())) - self.assertEquals(1, len(op2.inputs)) + self.assertEqual(0, len(op2.values())) + self.assertEqual(1, len(op2.inputs)) self.assertIs(float_t, op2.inputs[0]) - self.assertEquals(1, len(float_t._consumers)) - self.assertEquals(op2, float_t._consumers[0]) + self.assertEqual(1, len(float_t._consumers)) + self.assertEqual(op2, float_t._consumers[0]) self.assertProtoEquals("op:'noop' name:'myop1'", op1.node_def) self.assertProtoEquals("op:'reop' name:'myop2' input:'myop1'", @@ -124,29 +124,29 @@ class OperationTest(test_util.TensorFlowTestCase): def testInputsAndOutputs(self): g = ops.Graph() op1 = ops.Operation( - ops._NodeDef("noop", "myop1"), g, [], [types.float32]) - self.assertEquals(1, len(op1.values())) + ops._NodeDef("noop", "myop1"), g, [], [dtypes.float32]) + self.assertEqual(1, len(op1.values())) float1_t, = op1.values() op2 = ops.Operation(ops._NodeDef("reop", "myop2"), g, - [], [types.float32, types.string]) - self.assertEquals(2, len(op2.values())) + [], [dtypes.float32, dtypes.string]) + self.assertEqual(2, len(op2.values())) float2_t, label2_str_t = op2.values() # Note that we consume label2_str_t twice here. op3 = ops.Operation(ops._NodeDef("add", "myop3"), g, [float1_t, label2_str_t, label2_str_t], - [types.float32, types.int32]) - self.assertEquals(2, len(op3.values())) + [dtypes.float32, dtypes.int32]) + self.assertEqual(2, len(op3.values())) - self.assertEquals(1, len(float1_t._consumers)) - self.assertEquals(op3, float1_t._consumers[0]) + self.assertEqual(1, len(float1_t._consumers)) + self.assertEqual(op3, float1_t._consumers[0]) - self.assertEquals(0, len(float2_t._consumers)) + self.assertEqual(0, len(float2_t._consumers)) - self.assertEquals(2, len(label2_str_t._consumers)) - self.assertEquals(op3, label2_str_t._consumers[0]) - self.assertEquals(op3, label2_str_t._consumers[1]) + self.assertEqual(2, len(label2_str_t._consumers)) + self.assertEqual(op3, label2_str_t._consumers[0]) + self.assertEqual(op3, label2_str_t._consumers[1]) self.assertProtoEquals(""" op:'add' name:'myop3' @@ -168,14 +168,14 @@ class OperationTest(test_util.TensorFlowTestCase): def testReferenceInput(self): g = ops.Graph() op1 = ops.Operation(ops._NodeDef("noop", "op1"), g, [], - [types.float32_ref, types.float32]) + [dtypes.float32_ref, dtypes.float32]) self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def) ref_t, nonref_t = op1.values() # NOTE(mrry): Must specify input_types to preserve ref-typed input. op2 = ops.Operation( ops._NodeDef("refop", "op2"), g, [ref_t, nonref_t], [], - input_types=[types.float32_ref, types.float32]) + input_types=[dtypes.float32_ref, dtypes.float32]) self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", op2.node_def) op3 = ops.Operation( @@ -199,34 +199,34 @@ class OperationTest(test_util.TensorFlowTestCase): pass g = ops.Graph() with self.assertRaises(RuntimeError): - g.create_op("shapeless_op", [], [types.float32]) + g.create_op("shapeless_op", [], [dtypes.float32]) def testNoShapeFunction(self): g = ops.Graph() op = ops.Operation(ops._NodeDef("op", "an_op"), g, - output_types = [types.float32]) - self.assertEquals(tensor_shape.unknown_shape(), - _apply_op(g, "an_op", [], [types.float32]).get_shape()) + output_types = [dtypes.float32]) + self.assertEqual(tensor_shape.unknown_shape(), + _apply_op(g, "an_op", [], [dtypes.float32]).get_shape()) class CreateOpTest(test_util.TensorFlowTestCase): def testNodeDefArgs(self): g = ops.Graph() - op1 = g.create_op("const", [], [types.float32], None, name="myop1") + op1 = g.create_op("const", [], [dtypes.float32], None, name="myop1") with g.device("/device:GPU"): op2 = g.create_op("add", [], - [types.float32, types.string], None, + [dtypes.float32, dtypes.string], None, name="myop2") op3 = g.create_op("foo", [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], - [types.float32, types.int32], + [dtypes.float32, dtypes.int32], None, name="myop3") - self.assertEquals(None, op1.device) - self.assertEquals("/device:GPU", op2.device) - self.assertEquals(None, op3.device) + self.assertEqual(None, op1.device) + self.assertEqual("/device:GPU", op2.device) + self.assertEqual(None, op3.device) self.assertProtoEquals("name:'myop1' op:'const'", op1.node_def) self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", op2.node_def) @@ -237,12 +237,12 @@ class CreateOpTest(test_util.TensorFlowTestCase): def testReferenceInput(self): g = ops.Graph() op1 = g.create_op("noop", [], - [types.float32_ref, types.float32], name="op1") + [dtypes.float32_ref, dtypes.float32], name="op1") self.assertProtoEquals("op:'noop' name:'op1'", op1.node_def) ref_t, nonref_t = op1.values() # NOTE(mrry): Must specify input_types to preserve ref-typed input. op2 = g.create_op("refop", [ref_t, nonref_t], [], - input_types=[types.float32_ref, types.float32], + input_types=[dtypes.float32_ref, dtypes.float32], name="op2") self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", op2.node_def) @@ -254,29 +254,29 @@ class CreateOpTest(test_util.TensorFlowTestCase): g = ops.Graph() g.finalize() with self.assertRaises(RuntimeError): - g.create_op("const", [], [types.float32], None, name="myop1") + g.create_op("const", [], [dtypes.float32], None, name="myop1") class ApplyOpTest(test_util.TensorFlowTestCase): def testNodeDefArgs(self): g = ops.Graph() - t1 = _apply_op(g, "const", [], [types.float32], name="myop1") + t1 = _apply_op(g, "const", [], [dtypes.float32], name="myop1") with g.device("/device:GPU"): t2 = _apply_op(g, "add", [], - [types.float32, types.string], + [dtypes.float32, dtypes.string], name="myop2") t3 = _apply_op(g, "foo", [t1, t2[1], t2[0]], - [types.float32, types.int32], name="myop3") + [dtypes.float32, dtypes.int32], name="myop3") self.assertTrue(isinstance(t1, ops.Tensor)) self.assertTrue(isinstance(t2, list)) self.assertTrue(isinstance(t3, list)) self.assertTrue(isinstance(t3[0], ops.Tensor)) - self.assertEquals("myop1", t1._as_node_def_input()) - self.assertEquals("myop2", t2[0]._as_node_def_input()) - self.assertEquals("myop2:1", t2[1]._as_node_def_input()) - self.assertEquals("myop3", t3[0]._as_node_def_input()) + self.assertEqual("myop1", t1._as_node_def_input()) + self.assertEqual("myop2", t2[0]._as_node_def_input()) + self.assertEqual("myop2:1", t2[1]._as_node_def_input()) + self.assertEqual("myop3", t3[0]._as_node_def_input()) # Validate that we got the right ops as well self.assertProtoEquals("name:'myop1' op:'const'", t1.op.node_def) self.assertProtoEquals("name:'myop2' op:'add' device:'/device:GPU'", @@ -288,15 +288,15 @@ class ApplyOpTest(test_util.TensorFlowTestCase): def testReferenceInput(self): g = ops.Graph() ref_t, nonref_t = _apply_op( - g, "noop", [], [types.float32_ref, types.float32], name="op1") + g, "noop", [], [dtypes.float32_ref, dtypes.float32], name="op1") self.assertProtoEquals("op:'noop' name:'op1'", ref_t.op.node_def) # NOTE(mrry): Must specify input_types to preserve ref-typed input. - out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [types.int32], - input_types=[types.float32_ref, types.float32], + out_2 = _apply_op(g, "refop", [ref_t, nonref_t], [dtypes.int32], + input_types=[dtypes.float32_ref, dtypes.float32], name="op2") self.assertProtoEquals("op:'refop' name:'op2' input:'op1' input:'op1:1'", out_2.op.node_def) - out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [types.int32], + out_3 = _apply_op(g, "nonrefop", [ref_t, nonref_t], [dtypes.int32], name="op3") self.assertProtoEquals("op:'nonrefop' name:'op3' input:'op1' input:'op1:1'", out_3.op.node_def) @@ -306,108 +306,108 @@ class NameStackTest(test_util.TensorFlowTestCase): def testBasics(self): g = ops.Graph() - self.assertEquals("foo", g.unique_name("foo")) - self.assertEquals("foo_1", g.unique_name("foo")) - self.assertEquals("foo_2", g.unique_name("foo")) - self.assertEquals("foo_1_1", g.unique_name("foo_1")) - self.assertEquals("foo_1_2", g.unique_name("foo_1")) - self.assertEquals("foo_1_2_1", g.unique_name("foo_1_2")) + self.assertEqual("foo", g.unique_name("foo")) + self.assertEqual("foo_1", g.unique_name("foo")) + self.assertEqual("foo_2", g.unique_name("foo")) + self.assertEqual("foo_1_1", g.unique_name("foo_1")) + self.assertEqual("foo_1_2", g.unique_name("foo_1")) + self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) with g.name_scope("bar"): - self.assertEquals("bar/foo", g.unique_name("foo")) - self.assertEquals("bar/foo_1", g.unique_name("foo")) + self.assertEqual("bar/foo", g.unique_name("foo")) + self.assertEqual("bar/foo_1", g.unique_name("foo")) with g.name_scope(None): - self.assertEquals("foo_3", g.unique_name("foo")) + self.assertEqual("foo_3", g.unique_name("foo")) with g.name_scope("baz"): - self.assertEquals("bar/baz/foo", g.unique_name("foo")) - self.assertEquals("bar/baz/foo_1", g.unique_name("foo")) + self.assertEqual("bar/baz/foo", g.unique_name("foo")) + self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) with g.name_scope("baz"): - self.assertEquals("bar/baz_1/foo", g.unique_name("foo")) - self.assertEquals("bar/baz_1/foo_1", g.unique_name("foo")) + self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) + self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) with g.name_scope("quux"): - self.assertEquals("quux/foo", g.unique_name("foo")) + self.assertEqual("quux/foo", g.unique_name("foo")) with g.name_scope("bar"): with g.name_scope("baz"): - self.assertEquals("bar_1/baz/foo", g.unique_name("foo")) - self.assertEquals("foo_4", g.unique_name("foo")) - self.assertEquals("bar_2", g.unique_name("bar")) + self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) + self.assertEqual("foo_4", g.unique_name("foo")) + self.assertEqual("bar_2", g.unique_name("bar")) def testOutOfOrderUniqueName(self): g = ops.Graph() - self.assertEquals("foo_2", g.unique_name("foo_2")) - self.assertEquals("foo", g.unique_name("foo")) - self.assertEquals("foo_1", g.unique_name("foo")) - self.assertEquals("foo_3", g.unique_name("foo")) + self.assertEqual("foo_2", g.unique_name("foo_2")) + self.assertEqual("foo", g.unique_name("foo")) + self.assertEqual("foo_1", g.unique_name("foo")) + self.assertEqual("foo_3", g.unique_name("foo")) class NameTest(test_util.TensorFlowTestCase): def testGenerateName(self): g = ops.Graph() - op0 = g.create_op("const", [], [types.float32, types.float32]) - self.assertEquals("const", op0.name) - self.assertEquals("const:0", op0.outputs[0].name) - self.assertEquals("const:1", op0.outputs[1].name) + op0 = g.create_op("const", [], [dtypes.float32, dtypes.float32]) + self.assertEqual("const", op0.name) + self.assertEqual("const:0", op0.outputs[0].name) + self.assertEqual("const:1", op0.outputs[1].name) - op1 = g.create_op("const", [], [types.float32]) - self.assertEquals("const_1", op1.name) - self.assertEquals("const_1:0", op1.outputs[0].name) + op1 = g.create_op("const", [], [dtypes.float32]) + self.assertEqual("const_1", op1.name) + self.assertEqual("const_1:0", op1.outputs[0].name) - op2 = g.create_op("const", [], [types.float32], name="my_op") - self.assertEquals("my_op", op2.name) - self.assertEquals("my_op:0", op2.outputs[0].name) + op2 = g.create_op("const", [], [dtypes.float32], name="my_op") + self.assertEqual("my_op", op2.name) + self.assertEqual("my_op:0", op2.outputs[0].name) def testname_scope(self): g = ops.Graph() with g.name_scope("foo") as foo: - self.assertEquals(foo, "foo/") + self.assertEqual(foo, "foo/") with g.name_scope("foo2") as foo2: - self.assertEquals(foo2, "foo/foo2/") + self.assertEqual(foo2, "foo/foo2/") with g.name_scope(None) as empty1: - self.assertEquals(empty1, "") + self.assertEqual(empty1, "") with g.name_scope("foo3") as foo3: - self.assertEquals(foo3, "foo3/") + self.assertEqual(foo3, "foo3/") with g.name_scope("") as empty2: - self.assertEquals(empty2, "") + self.assertEqual(empty2, "") - self.assertEquals("const", - g.create_op("const", [], [types.float32]).name) + self.assertEqual("const", + g.create_op("const", [], [dtypes.float32]).name) with g.name_scope("bar") as scope: - self.assertEquals("bar/const", - g.create_op("const", [], [types.float32]).name) - self.assertEquals("bar/const_1", - g.create_op("const", [], [types.float32]).name) + self.assertEqual("bar/const", + g.create_op("const", [], [dtypes.float32]).name) + self.assertEqual("bar/const_1", + g.create_op("const", [], [dtypes.float32]).name) # If you use the value from "with .. as", that values is used as-is. - self.assertEquals( + self.assertEqual( "bar", - g.create_op("const", [], [types.float32], name=scope).name) + g.create_op("const", [], [dtypes.float32], name=scope).name) with g.name_scope("baz") as scope: with g.name_scope("quux"): - self.assertEquals("baz/quux/const", - g.create_op("const", [], [types.float32]).name) + self.assertEqual("baz/quux/const", + g.create_op("const", [], [dtypes.float32]).name) # If you use the value from the enclosing "with .. as", nothing is pushed. with g.name_scope(scope): - self.assertEquals("baz/const", - g.create_op("const", [], [types.float32]).name) - self.assertEquals("baz", - g.create_op("const", [], [types.float32], + self.assertEqual("baz/const", + g.create_op("const", [], [dtypes.float32]).name) + self.assertEqual("baz", + g.create_op("const", [], [dtypes.float32], name=scope).name) - self.assertEquals("trailing", - g.create_op("const", [], [types.float32], + self.assertEqual("trailing", + g.create_op("const", [], [dtypes.float32], name="trailing/").name) with g.name_scope("bar"): - self.assertEquals("bar_1/const", - g.create_op("const", [], [types.float32]).name) + self.assertEqual("bar_1/const", + g.create_op("const", [], [dtypes.float32]).name) with g.name_scope("bar/"): - self.assertEquals("bar/const_2", - g.create_op("const", [], [types.float32]).name) + self.assertEqual("bar/const_2", + g.create_op("const", [], [dtypes.float32]).name) class DeviceTest(test_util.TensorFlowTestCase): def testNoDevice(self): g = ops.Graph() - op = g.create_op("an_op", [], [types.float32]) + op = g.create_op("an_op", [], [dtypes.float32]) self.assertEqual(None, op.device) gd = g.as_graph_def() self.assertProtoEquals(""" @@ -417,7 +417,7 @@ class DeviceTest(test_util.TensorFlowTestCase): def testDevicePartialString(self): g = ops.Graph() with g.device("/job:worker/replica:2"): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" node { name: "an_op" op: "an_op" device: "/job:worker/replica:2" } @@ -428,7 +428,7 @@ class DeviceTest(test_util.TensorFlowTestCase): with g.device(pydev.Device(job="worker", replica=2, task=0, device_type="CPU", device_index=3)): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" node { name: "an_op" op: "an_op" @@ -438,10 +438,10 @@ class DeviceTest(test_util.TensorFlowTestCase): def testNesting(self): g = ops.Graph() with g.device("/job:worker/replica:2"): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device("/job:worker/replica:3/task:0"): - g.create_op("an_op", [], [types.float32]) - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" node { name: "an_op" op: "an_op" @@ -455,10 +455,10 @@ class DeviceTest(test_util.TensorFlowTestCase): def testNestingString(self): g = ops.Graph() with g.device("/job:worker/replica:2"): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device("/job:worker/replica:3/task:0"): - g.create_op("an_op", [], [types.float32]) - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" node { name: "an_op" op: "an_op" @@ -472,10 +472,10 @@ class DeviceTest(test_util.TensorFlowTestCase): def testNestingOverrideGpuCpu(self): g = ops.Graph() with g.device("/job:worker/replica:2/device:CPU:1"): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device("/job:worker/replica:2/device:GPU:2"): - g.create_op("an_op", [], [types.float32]) - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" node { name: "an_op" op: "an_op" @@ -490,15 +490,15 @@ class DeviceTest(test_util.TensorFlowTestCase): g = ops.Graph() with g.device(pydev.merge_device("/device:GPU:0")): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device(pydev.merge_device("/job:worker")): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device(pydev.merge_device("/device:CPU:0")): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device(pydev.merge_device("/job:ps")): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device(pydev.merge_device(None)): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" @@ -517,10 +517,10 @@ class DeviceTest(test_util.TensorFlowTestCase): def testNoneClearsDefault(self): g = ops.Graph() with g.device("/job:worker/replica:2/device:CPU:1"): - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) with g.device(None): - g.create_op("an_op", [], [types.float32]) - g.create_op("an_op", [], [types.float32]) + g.create_op("an_op", [], [dtypes.float32]) + g.create_op("an_op", [], [dtypes.float32]) gd = g.as_graph_def() self.assertProtoEquals(""" node { name: "an_op" op: "an_op" @@ -556,22 +556,22 @@ class CollectionTest(test_util.TensorFlowTestCase): blank2 = ObjectWithName("junk/foo") g.add_to_collection("blah", blank2) - self.assertEquals(["foo"], g.get_collection("other")) - self.assertEquals([12, 34], g.get_collection("key")) - self.assertEquals([], g.get_collection("nothing")) - self.assertEquals([27, blank1, blank2], g.get_collection("blah")) - self.assertEquals([blank1], g.get_collection("blah", "prefix")) + self.assertEqual(["foo"], g.get_collection("other")) + self.assertEqual([12, 34], g.get_collection("key")) + self.assertEqual([], g.get_collection("nothing")) + self.assertEqual([27, blank1, blank2], g.get_collection("blah")) + self.assertEqual([blank1], g.get_collection("blah", "prefix")) def testDefaulGraph(self): with ops.Graph().as_default(): ops.add_to_collection("key", 90) ops.add_to_collection("key", 100) # Collections are ordered. - self.assertEquals([90, 100], ops.get_collection("key")) + self.assertEqual([90, 100], ops.get_collection("key")) def an_op(g): - return _apply_op(g, "an_op", [], [types.float32]) + return _apply_op(g, "an_op", [], [dtypes.float32]) ops.NoGradient("an_op") @@ -600,7 +600,7 @@ class RegistrationTest(test_util.TensorFlowTestCase): x = an_op(g) y = copy_op(x) fn = ops.get_gradient_function(y.op) - self.assertEquals(_CopyGrad, fn) + self.assertEqual(_CopyGrad, fn) def testOverrideGradients(self): g = ops.Graph() @@ -608,7 +608,7 @@ class RegistrationTest(test_util.TensorFlowTestCase): with g.gradient_override_map({"copy": "copy_override"}): y = copy_op(x) fn = ops.get_gradient_function(y.op) - self.assertEquals(_CopyOverrideGrad, fn) + self.assertEqual(_CopyOverrideGrad, fn) def testNonExistentOverride(self): g = ops.Graph() @@ -623,8 +623,8 @@ class ComparisonTest(test_util.TensorFlowTestCase): def testMembershipAllowed(self): g = ops.Graph() - t1 = _apply_op(g, "const", [], [types.float32], name="myop1") - t2 = _apply_op(g, "const", [], [types.float32], name="myop2") + t1 = _apply_op(g, "const", [], [dtypes.float32], name="myop1") + t2 = _apply_op(g, "const", [], [dtypes.float32], name="myop2") self.assertTrue(isinstance(t1, ops.Tensor)) self.assertTrue(isinstance(t2, ops.Tensor)) self.assertTrue(t1 in [t1]) @@ -635,12 +635,12 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): def testBasic(self): g = ops.Graph() - a = _apply_op(g, "const", [], [types.float32]) - b = _apply_op(g, "const", [], [types.float32]) + a = _apply_op(g, "const", [], [dtypes.float32]) + b = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a]): - c = _apply_op(g, "const", [], [types.float32]) - d = _apply_op(g, "identity", [b], [types.float32]) - e = _apply_op(g, "identity", [c], [types.float32]) + c = _apply_op(g, "const", [], [dtypes.float32]) + d = _apply_op(g, "identity", [b], [dtypes.float32]) + e = _apply_op(g, "identity", [c], [dtypes.float32]) self.assertEqual(c.op.control_inputs, [a.op]) self.assertEqual(d.op.control_inputs, [a.op]) @@ -649,7 +649,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): def testBasicWithConversion(self): g = ops.Graph() - a = _apply_op(g, "const", [], [types.float32]) + a = _apply_op(g, "const", [], [dtypes.float32]) class ConvertibleObj(object): @@ -657,25 +657,25 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): return a with g.control_dependencies([ConvertibleObj()]): - c = _apply_op(g, "const", [], [types.float32]) + c = _apply_op(g, "const", [], [dtypes.float32]) self.assertEqual(c.op.control_inputs, [a.op]) def testNested(self): g = ops.Graph() - a_1 = _apply_op(g, "const", [], [types.float32]) - a_2 = _apply_op(g, "const", [], [types.float32]) - a_3 = _apply_op(g, "const", [], [types.float32]) - a_4 = _apply_op(g, "const", [], [types.float32]) + a_1 = _apply_op(g, "const", [], [dtypes.float32]) + a_2 = _apply_op(g, "const", [], [dtypes.float32]) + a_3 = _apply_op(g, "const", [], [dtypes.float32]) + a_4 = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a_1, a_2, a_3, a_4]): - b_1 = _apply_op(g, "const", [], [types.float32]) + b_1 = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a_1]): with g.control_dependencies([a_2]): with g.control_dependencies([a_3]): with g.control_dependencies([a_4]): - b_2 = _apply_op(g, "const", [], [types.float32]) + b_2 = _apply_op(g, "const", [], [dtypes.float32]) self.assertItemsEqual( [a_1.op, a_2.op, a_3.op, a_4.op], b_1.op.control_inputs) @@ -692,31 +692,31 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. - a_1 = _apply_op(g, "const", [], [types.float32]) - a_2 = _apply_op(g, "const", [], [types.float32]) - a_3 = _apply_op(g, "const", [], [types.float32]) - a_4 = _apply_op(g, "const", [], [types.float32]) + a_1 = _apply_op(g, "const", [], [dtypes.float32]) + a_2 = _apply_op(g, "const", [], [dtypes.float32]) + a_3 = _apply_op(g, "const", [], [dtypes.float32]) + a_4 = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a_1]): - b_1 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) - c_1 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) - d_1 = _apply_op(g, "mul", [b_1, c_1], [types.float32]) - e_1 = _apply_op(g, "const", [], [types.float32]) + b_1 = _apply_op(g, "mul", [a_3, a_4], [dtypes.float32]) + c_1 = _apply_op(g, "mul", [a_1, b_1], [dtypes.float32]) + d_1 = _apply_op(g, "mul", [b_1, c_1], [dtypes.float32]) + e_1 = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a_2]): - b_2 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) - c_2 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) - d_2 = _apply_op(g, "mul", [b_2, c_2], [types.float32]) - e_2 = _apply_op(g, "mul", [e_1, e_1], [types.float32]) + b_2 = _apply_op(g, "mul", [a_3, a_4], [dtypes.float32]) + c_2 = _apply_op(g, "mul", [a_1, b_1], [dtypes.float32]) + d_2 = _apply_op(g, "mul", [b_2, c_2], [dtypes.float32]) + e_2 = _apply_op(g, "mul", [e_1, e_1], [dtypes.float32]) with g.control_dependencies([a_3]): - b_3 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) - c_3 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) - d_3 = _apply_op(g, "mul", [b_3, c_3], [types.float32]) - e_3 = _apply_op(g, "mul", [e_2, e_2], [types.float32]) + b_3 = _apply_op(g, "mul", [a_3, a_4], [dtypes.float32]) + c_3 = _apply_op(g, "mul", [a_1, b_1], [dtypes.float32]) + d_3 = _apply_op(g, "mul", [b_3, c_3], [dtypes.float32]) + e_3 = _apply_op(g, "mul", [e_2, e_2], [dtypes.float32]) with g.control_dependencies([a_4]): - b_4 = _apply_op(g, "mul", [a_3, a_4], [types.float32]) - c_4 = _apply_op(g, "mul", [a_1, b_1], [types.float32]) - d_4 = _apply_op(g, "mul", [b_4, c_4], [types.float32]) - e_4 = _apply_op(g, "mul", [e_3, e_3], [types.float32]) + b_4 = _apply_op(g, "mul", [a_3, a_4], [dtypes.float32]) + c_4 = _apply_op(g, "mul", [a_1, b_1], [dtypes.float32]) + d_4 = _apply_op(g, "mul", [b_4, c_4], [dtypes.float32]) + e_4 = _apply_op(g, "mul", [e_3, e_3], [dtypes.float32]) self.assertItemsEqual([a_1.op], b_1.op.control_inputs) self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) @@ -740,21 +740,21 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase): def testRepeatedDependency(self): g = ops.Graph() - a = g.create_op("foo", [], [types.float32, types.float32]) + a = g.create_op("foo", [], [dtypes.float32, dtypes.float32]) a_0, a_1 = a.outputs with g.control_dependencies([a_0]): - b = _apply_op(g, "const", [], [types.float32]) + b = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a_1]): - c = _apply_op(g, "const", [], [types.float32]) + c = _apply_op(g, "const", [], [dtypes.float32]) self.assertEqual(b.op.control_inputs, [a]) self.assertEqual(c.op.control_inputs, [a]) def testNoControlDependencyWithDataDependency(self): g = ops.Graph() - a = _apply_op(g, "const", [], [types.float32]) + a = _apply_op(g, "const", [], [dtypes.float32]) with g.control_dependencies([a]): - b = _apply_op(g, "identity", [a], [types.float32]) + b = _apply_op(g, "identity", [a], [dtypes.float32]) self.assertEqual(b.op.control_inputs, []) @@ -797,27 +797,27 @@ class GraphTest(test_util.TensorFlowTestCase): pass g = ops.Graph() - a = _apply_op(g, "const", [], [types.float32]) + a = _apply_op(g, "const", [], [dtypes.float32]) self.assertEqual(a, g.as_graph_element(ConvertibleObj())) with self.assertRaises(TypeError): g.as_graph_element(NonConvertibleObj()) def testAssertSameGraph(self): g0 = ops.Graph() - a = g0.create_op("a", [], [types.float32]) - b = g0.create_op("b", [], [types.float32]) + a = g0.create_op("a", [], [dtypes.float32]) + b = g0.create_op("b", [], [dtypes.float32]) ops.assert_same_graph([a, b]) ops.assert_same_graph([a, b], g0) g1 = ops.Graph() - c = g1.create_op("c", [], [types.float32]) + c = g1.create_op("c", [], [dtypes.float32]) self.assertRaises(ValueError, ops.assert_same_graph, [a, b, c]) self.assertRaises(ValueError, ops.assert_same_graph, [c], g0) self.assertRaises(ValueError, ops.assert_same_graph, [a], g1) sparse = ops.SparseTensor( - _apply_op(g0, "const", [], [types.int64]), - _apply_op(g0, "const", [], [types.float32]), - _apply_op(g0, "const", [], [types.int64])) + _apply_op(g0, "const", [], [dtypes.int64]), + _apply_op(g0, "const", [], [dtypes.float32]), + _apply_op(g0, "const", [], [dtypes.int64])) ops.assert_same_graph([sparse, a, b]) ops.assert_same_graph([sparse, a, b], g0) self.assertRaises(ValueError, ops.assert_same_graph, [sparse, a, c]) @@ -830,7 +830,7 @@ class KernelLabelTest(test_util.TensorFlowTestCase): def testNoLabel(self): with self.test_session(): - self.assertAllEqual("My label is: default", + self.assertAllEqual(b"My label is: default", test_kernel_label_op.kernel_label().eval()) def testLabelMap(self): @@ -847,12 +847,12 @@ class KernelLabelTest(test_util.TensorFlowTestCase): # pylint: enable=protected-access default_3 = test_kernel_label_op.kernel_label() - self.assertAllEqual("My label is: default", default_1.eval()) - self.assertAllEqual("My label is: default", default_2.eval()) - self.assertAllEqual("My label is: default", default_3.eval()) - self.assertAllEqual("My label is: overload_1", overload_1_1.eval()) - self.assertAllEqual("My label is: overload_1", overload_1_2.eval()) - self.assertAllEqual("My label is: overload_2", overload_2.eval()) + self.assertAllEqual(b"My label is: default", default_1.eval()) + self.assertAllEqual(b"My label is: default", default_2.eval()) + self.assertAllEqual(b"My label is: default", default_3.eval()) + self.assertAllEqual(b"My label is: overload_1", overload_1_1.eval()) + self.assertAllEqual(b"My label is: overload_1", overload_1_2.eval()) + self.assertAllEqual(b"My label is: overload_2", overload_2.eval()) if __name__ == "__main__": diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py index de91cc2511e..c1a10131529 100644 --- a/tensorflow/python/framework/registry.py +++ b/tensorflow/python/framework/registry.py @@ -26,6 +26,7 @@ from __future__ import print_function import traceback from tensorflow.python.platform import logging +from tensorflow.python.util import compat # Registry mechanism below is based on mapreduce.python.mrpython.Register. @@ -45,8 +46,8 @@ class Registry(object): """Registers a Python object "candidate" for the given "name". Args: - candidate: the candidate object to add to the registry. - name: an optional string specifying the registry key for the candidate. + candidate: The candidate object to add to the registry. + name: An optional string specifying the registry key for the candidate. If None, candidate.__name__ will be used. Raises: KeyError: If same name is used twice. @@ -76,6 +77,7 @@ class Registry(object): Raises: LookupError: if "name" has not been registered. """ + name = compat.as_str(name) if name in self._registry: return self._registry[name][_TYPE_TAG] else: diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 83d27de0041..6914db0d34c 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -51,6 +51,10 @@ class Dimension(object): def __int__(self): return self._value + def __index__(self): + # Allow use in Python 3 range + return self._value + @property def value(self): """The value of this dimension, or None if it is unknown.""" @@ -705,7 +709,7 @@ class TensorShape(object): raise ValueError("Shape %s is not fully defined" % self) def as_dimension_list(self): - """DEPRECATED: use as_list().""" + """DEPRECATED: use `as_list()`.""" self.assert_is_fully_defined() return self.as_list() diff --git a/tensorflow/python/framework/tensor_shape_div_test.py b/tensorflow/python/framework/tensor_shape_div_test.py index 4a0e39ec86b..062d0b916eb 100644 --- a/tensorflow/python/framework/tensor_shape_div_test.py +++ b/tensorflow/python/framework/tensor_shape_div_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import tensorflow.python.platform +import six + from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -29,10 +31,11 @@ class DimensionDivTest(test_util.TensorFlowTestCase): def testDivSucceeds(self): """Without from __future__ import division, __div__ should work.""" - values = [tensor_shape.Dimension(x) for x in 3, 7, 11, None] - for x in values: - for y in values: - self.assertEqual((x / y).value, (x // y).value) + if six.PY2: # Old division exists only in Python 2 + values = [tensor_shape.Dimension(x) for x in (3, 7, 11, None)] + for x in values: + for y in values: + self.assertEqual((x / y).value, (x // y).value) if __name__ == "__main__": diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index b904912da7a..7802db473e4 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numbers import tensorflow.python.platform import numpy as np import six from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.python.util import compat # TODO(opensource): Add support for pyx_library in the open-source build. # For now, we use the slow versions that fast_tensor_util replaces. @@ -35,8 +35,8 @@ try: except ImportError: _FAST_TENSOR_UTIL_AVAILABLE = False +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types # pylint: enable=g-import-not-at-top @@ -53,10 +53,11 @@ if _FAST_TENSOR_UTIL_AVAILABLE: np.complex128: fast_tensor_util.AppendComplex128ArrayToTensorProto, np.object: fast_tensor_util.AppendObjectArrayToTensorProto, np.bool: fast_tensor_util.AppendBoolArrayToTensorProto, - types.qint8.as_numpy_dtype: fast_tensor_util.AppendInt8ArrayToTensorProto, - types.quint8.as_numpy_dtype: + dtypes.qint8.as_numpy_dtype: + fast_tensor_util.AppendInt8ArrayToTensorProto, + dtypes.quint8.as_numpy_dtype: fast_tensor_util.AppendUInt8ArrayToTensorProto, - types.qint32.as_numpy_dtype: + dtypes.qint32.as_numpy_dtype: fast_tensor_util.AppendInt32ArrayToTensorProto, # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. } @@ -80,7 +81,7 @@ else: for v in [x.real, x.imag]]) def SlowAppendObjectArrayToTensorProto(tensor_proto, proto_values): - tensor_proto.string_val.extend([str(x) for x in proto_values]) + tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) def SlowAppendBoolArrayToTensorProto(tensor_proto, proto_values): tensor_proto.bool_val.extend([np.asscalar(x) for x in proto_values]) @@ -97,9 +98,9 @@ else: np.complex128: SlowAppendComplexArrayToTensorProto, np.object: SlowAppendObjectArrayToTensorProto, np.bool: SlowAppendBoolArrayToTensorProto, - types.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, - types.quint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, - types.qint32.as_numpy_dtype: SlowAppendIntArrayToTensorProto, + dtypes.qint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, + dtypes.quint8.as_numpy_dtype: SlowAppendIntArrayToTensorProto, + dtypes.qint32.as_numpy_dtype: SlowAppendIntArrayToTensorProto, # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. } @@ -170,8 +171,8 @@ def _FlattenToStrings(nested_strings): _TENSOR_CONTENT_TYPES = frozenset([ - types.float32, types.float64, types.int32, types.uint8, types.int16, - types.int8, types.int64 + dtypes.float32, dtypes.float64, dtypes.int32, dtypes.uint8, dtypes.int16, + dtypes.int8, dtypes.int64 ]) @@ -204,25 +205,25 @@ def _NotNone(v): def _FilterInt(v): if isinstance(v, (list, tuple)): return _FirstNotNone([_FilterInt(x) for x in v]) - return None if isinstance(v, numbers.Integral) else _NotNone(v) + return None if isinstance(v, compat.integral_types) else _NotNone(v) def _FilterFloat(v): if isinstance(v, (list, tuple)): return _FirstNotNone([_FilterFloat(x) for x in v]) - return None if isinstance(v, numbers.Real) else _NotNone(v) + return None if isinstance(v, compat.real_types) else _NotNone(v) def _FilterComplex(v): if isinstance(v, (list, tuple)): return _FirstNotNone([_FilterComplex(x) for x in v]) - return None if isinstance(v, numbers.Complex) else _NotNone(v) + return None if isinstance(v, compat.complex_types) else _NotNone(v) def _FilterStr(v): if isinstance(v, (list, tuple)): return _FirstNotNone([_FilterStr(x) for x in v]) - if isinstance(v, (six.string_types, six.binary_type)): + if isinstance(v, compat.bytes_or_text_types): return None else: return _NotNone(v) @@ -241,19 +242,19 @@ def _FilterNotTensor(v): _TF_TO_IS_OK = { - types.float32: _FilterFloat, - types.float64: _FilterFloat, - types.int32: _FilterInt, - types.uint8: _FilterInt, - types.int16: _FilterInt, - types.int8: _FilterInt, - types.string: _FilterStr, - types.complex64: _FilterComplex, - types.int64: _FilterInt, - types.bool: _FilterBool, - types.qint32: _FilterInt, - types.quint8: _FilterInt, - types.qint8: _FilterInt, + dtypes.float32: _FilterFloat, + dtypes.float64: _FilterFloat, + dtypes.int32: _FilterInt, + dtypes.uint8: _FilterInt, + dtypes.int16: _FilterInt, + dtypes.int8: _FilterInt, + dtypes.string: _FilterStr, + dtypes.complex64: _FilterComplex, + dtypes.int64: _FilterInt, + dtypes.bool: _FilterBool, + dtypes.qint32: _FilterInt, + dtypes.quint8: _FilterInt, + dtypes.qint8: _FilterInt, } @@ -264,8 +265,8 @@ def _AssertCompatible(values, dtype): if dtype is None: raise TypeError("List of Tensors when single Tensor expected") else: - raise TypeError("Expected %s, got %s instead." % - (dtype.name, repr(mismatch))) + raise TypeError("Expected %s, got %s of type '%s' instead." % + (dtype.name, repr(mismatch), type(mismatch).__name__)) def make_tensor_proto(values, dtype=None, shape=None): @@ -308,7 +309,7 @@ def make_tensor_proto(values, dtype=None, shape=None): """ if dtype: - dtype = types.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) # We first convert value to a numpy array or scalar. if isinstance(values, (np.ndarray, np.generic)): @@ -338,13 +339,13 @@ def make_tensor_proto(values, dtype=None, shape=None): # if dtype is provided, it must be compatible with what numpy # conversion says. - numpy_dtype = types.as_dtype(nparray.dtype) + numpy_dtype = dtypes.as_dtype(nparray.dtype) if numpy_dtype is None: raise TypeError("Unrecognized data type: %s" % nparray.dtype) # If dtype was specified and is a quantized type, we convert # numpy_dtype back into the quantized version. - if dtype in [types.qint8, types.quint8, types.qint32]: + if dtype in [dtypes.qint8, dtypes.quint8, dtypes.qint32]: numpy_dtype = dtype if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype: @@ -378,9 +379,9 @@ def make_tensor_proto(values, dtype=None, shape=None): # strings. Since values could be a list of strings, or a multi-dimensional # list of lists that might or might not correspond to the given shape, # we flatten it conservatively. - if numpy_dtype == types.string and not isinstance(values, np.ndarray): + if numpy_dtype == dtypes.string and not isinstance(values, np.ndarray): proto_values = _FlattenToStrings(values) - tensor_proto.string_val.extend([str(x) for x in proto_values]) + tensor_proto.string_val.extend([compat.as_bytes(x) for x in proto_values]) return tensor_proto # TensorFlow expects C order (a.k.a., eigen row major). @@ -412,45 +413,45 @@ def MakeNdarray(tensor): """ shape = [d.size for d in tensor.tensor_shape.dim] num_elements = np.prod(shape) - tensor_dtype = types.as_dtype(tensor.dtype) + tensor_dtype = dtypes.as_dtype(tensor.dtype) dtype = tensor_dtype.as_numpy_dtype if tensor.tensor_content: return np.fromstring(tensor.tensor_content, dtype=dtype).reshape(shape) - elif tensor_dtype == types.float32: + elif tensor_dtype == dtypes.float32: if len(tensor.float_val) == 1: return np.repeat(np.array(tensor.float_val[0], dtype=dtype), num_elements).reshape(shape) else: return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape) - elif tensor_dtype == types.float64: + elif tensor_dtype == dtypes.float64: if len(tensor.double_val) == 1: return np.repeat(np.array(tensor.double_val[0], dtype=dtype), num_elements).reshape(shape) else: return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape) - elif tensor_dtype in [types.int32, types.uint8, types.int16, types.int8, - types.qint32, types.quint8, types.qint8, - types.bfloat16]: + elif tensor_dtype in [dtypes.int32, dtypes.uint8, dtypes.int16, dtypes.int8, + dtypes.qint32, dtypes.quint8, dtypes.qint8, + dtypes.bfloat16]: if len(tensor.int_val) == 1: return np.repeat(np.array(tensor.int_val[0], dtype=dtype), num_elements).reshape(shape) else: return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape) - elif tensor_dtype == types.int64: + elif tensor_dtype == dtypes.int64: if len(tensor.int64_val) == 1: return np.repeat(np.array(tensor.int64_val[0], dtype=dtype), num_elements).reshape(shape) else: return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape) - elif tensor_dtype == types.string: + elif tensor_dtype == dtypes.string: if len(tensor.string_val) == 1: - return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype), + return np.repeat(np.array(tensor.string_val[0], dtype=dtype), num_elements).reshape(shape) else: - return np.array([str(x) for x in tensor.string_val], + return np.array([x for x in tensor.string_val], dtype=dtype).reshape(shape) - elif tensor_dtype == types.complex64: + elif tensor_dtype == dtypes.complex64: it = iter(tensor.scomplex_val) if len(tensor.scomplex_val) == 2: return np.repeat(np.array(complex(tensor.scomplex_val[0], @@ -459,7 +460,7 @@ def MakeNdarray(tensor): else: return np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype).reshape(shape) - elif tensor_dtype == types.bool: + elif tensor_dtype == dtypes.bool: if len(tensor.bool_val) == 1: return np.repeat(np.array(tensor.bool_val[0], dtype=dtype), num_elements).reshape(shape) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index 6a05d2d54b5..c7e672d460e 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -22,9 +22,9 @@ import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import state_ops @@ -56,7 +56,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) def testFloatTyped(self): - t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=types.float32) + t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) self.assertProtoEquals(""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } @@ -67,7 +67,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): self.assertAllClose(np.array([10.0, 20.0, 30.0], dtype=np.float32), a) def testFloatTypeCoerce(self): - t = tensor_util.make_tensor_proto([10, 20, 30], dtype=types.float32) + t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32) self.assertProtoEquals(""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } @@ -79,7 +79,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testFloatTypeCoerceNdarray(self): arr = np.asarray([10, 20, 30], dtype="int") - t = tensor_util.make_tensor_proto(arr, dtype=types.float32) + t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32) self.assertProtoEquals(""" dtype: DT_FLOAT tensor_shape { dim { size: 3 } } @@ -136,7 +136,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testFloatTypesWithImplicitRepeat(self): for dtype, nptype in [ - (types.float32, np.float32), (types.float64, np.float64)]: + (dtypes.float32, np.float32), (dtypes.float64, np.float64)]: t = tensor_util.make_tensor_proto([10.0], shape=[3, 4], dtype=dtype) a = tensor_util.MakeNdarray(t) self.assertAllClose(np.array([[10.0, 10.0, 10.0, 10.0], @@ -167,10 +167,10 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testIntTypes(self): for dtype, nptype in [ - (types.int32, np.int32), - (types.uint8, np.uint8), - (types.int16, np.int16), - (types.int8, np.int8)]: + (dtypes.int32, np.int32), + (dtypes.uint8, np.uint8), + (dtypes.int16, np.int16), + (dtypes.int8, np.int8)]: # Test with array. t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtype) self.assertEquals(dtype, t.dtype) @@ -188,11 +188,11 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testIntTypesWithImplicitRepeat(self): for dtype, nptype in [ - (types.int64, np.int64), - (types.int32, np.int32), - (types.uint8, np.uint8), - (types.int16, np.int16), - (types.int8, np.int8)]: + (dtypes.int64, np.int64), + (dtypes.int32, np.int32), + (dtypes.uint8, np.uint8), + (dtypes.int16, np.int16), + (dtypes.int8, np.int8)]: t = tensor_util.make_tensor_proto([10], shape=[3, 4], dtype=dtype) a = tensor_util.MakeNdarray(t) self.assertAllEqual(np.array([[10, 10, 10, 10], @@ -200,7 +200,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): [10, 10, 10, 10]], dtype=nptype), a) def testLong(self): - t = tensor_util.make_tensor_proto(10, dtype=types.int64) + t = tensor_util.make_tensor_proto(10, dtype=dtypes.int64) self.assertProtoEquals(""" dtype: DT_INT64 tensor_shape {} @@ -212,7 +212,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testLongN(self): t = tensor_util.make_tensor_proto([10, 20, 30], shape=[1, 3], - dtype=types.int64) + dtype=dtypes.int64) self.assertProtoEquals(""" dtype: DT_INT64 tensor_shape { dim { size: 1 } dim { size: 3 } } @@ -242,17 +242,15 @@ class TensorUtilTest(test_util.TensorFlowTestCase): """, t) a = tensor_util.MakeNdarray(t) self.assertEquals(np.object, a.dtype) - self.assertEquals(["foo"], a) + self.assertEquals([b"foo"], a) def testStringWithImplicitRepeat(self): t = tensor_util.make_tensor_proto("f", shape=[3, 4]) a = tensor_util.MakeNdarray(t) - self.assertAllEqual(np.array([["f", "f", "f", "f"], - ["f", "f", "f", "f"], - ["f", "f", "f", "f"]], dtype=np.object), a) + self.assertAllEqual(np.array([[b"f"] * 4] * 3, dtype=np.object), a) def testStringN(self): - t = tensor_util.make_tensor_proto(["foo", "bar", "baz"], shape=[1, 3]) + t = tensor_util.make_tensor_proto([b"foo", b"bar", b"baz"], shape=[1, 3]) self.assertProtoEquals(""" dtype: DT_STRING tensor_shape { dim { size: 1 } dim { size: 3 } } @@ -262,10 +260,11 @@ class TensorUtilTest(test_util.TensorFlowTestCase): """, t) a = tensor_util.MakeNdarray(t) self.assertEquals(np.object, a.dtype) - self.assertAllEqual(np.array([["foo", "bar", "baz"]]), a) + self.assertAllEqual(np.array([[b"foo", b"bar", b"baz"]]), a) def testStringNpArray(self): - t = tensor_util.make_tensor_proto(np.array([["a", "ab"], ["abc", "abcd"]])) + t = tensor_util.make_tensor_proto(np.array([[b"a", b"ab"], + [b"abc", b"abcd"]])) self.assertProtoEquals(""" dtype: DT_STRING tensor_shape { dim { size: 2 } dim { size: 2 } } @@ -276,10 +275,10 @@ class TensorUtilTest(test_util.TensorFlowTestCase): """, t) a = tensor_util.MakeNdarray(t) self.assertEquals(np.object, a.dtype) - self.assertAllEqual(np.array([["a", "ab"], ["abc", "abcd"]]), a) + self.assertAllEqual(np.array([[b"a", b"ab"], [b"abc", b"abcd"]]), a) def testComplex(self): - t = tensor_util.make_tensor_proto((1+2j), dtype=types.complex64) + t = tensor_util.make_tensor_proto((1+2j), dtype=dtypes.complex64) self.assertProtoEquals(""" dtype: DT_COMPLEX64 tensor_shape {} @@ -292,7 +291,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testComplexWithImplicitRepeat(self): t = tensor_util.make_tensor_proto((1+1j), shape=[3, 4], - dtype=types.complex64) + dtype=dtypes.complex64) a = tensor_util.MakeNdarray(t) self.assertAllClose(np.array([[(1+1j), (1+1j), (1+1j), (1+1j)], [(1+1j), (1+1j), (1+1j), (1+1j)], @@ -301,7 +300,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testComplexN(self): t = tensor_util.make_tensor_proto([(1+2j), (3+4j), (5+6j)], shape=[1, 3], - dtype=types.complex64) + dtype=dtypes.complex64) self.assertProtoEquals(""" dtype: DT_COMPLEX64 tensor_shape { dim { size: 1 } dim { size: 3 } } @@ -318,7 +317,7 @@ class TensorUtilTest(test_util.TensorFlowTestCase): def testComplexNpArray(self): t = tensor_util.make_tensor_proto( - np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=types.complex64) + np.array([[(1+2j), (3+4j)], [(5+6j), (7+8j)]]), dtype=dtypes.complex64) # scomplex_val are real_0, imag_0, real_1, imag_1, ... self.assertProtoEquals(""" dtype: DT_COMPLEX64 @@ -375,7 +374,7 @@ class ConstantValueTest(test_util.TensorFlowTestCase): self.assertAllClose(np_val, tensor_util.ConstantValue(tf_val)) def testUnknown(self): - tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=types.float32) + tf_val = state_ops.variable_op(shape=[3, 4, 7], dtype=dtypes.float32) self.assertIs(None, tensor_util.ConstantValue(tf_val)) def testShape(self): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c8a6fcacbca..91d2aa48fe7 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -22,11 +22,13 @@ from __future__ import print_function import contextlib import math import re +import sys import threading import tensorflow.python.platform import numpy as np +import six from google.protobuf import text_format @@ -454,3 +456,11 @@ class TensorFlowTestCase(googletest.TestCase): if not isinstance(tf_tensor, ops.Tensor): raise TypeError("tf_tensor must be a Tensor") self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list()) + + # Fix Python 3 compatibility issues + if six.PY3: + # Silence a deprecation warning + assertRaisesRegexp = googletest.TestCase.assertRaisesRegex + + # assertItemsEqual is assertCountEqual as of 3.2. + assertItemsEqual = googletest.TestCase.assertCountEqual diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 4a61c65a27b..69af7640b61 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -27,10 +27,10 @@ from six.moves import xrange # pylint: disable=redefined-builtin from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.platform import googletest from tensorflow.python.ops import logging_ops @@ -83,8 +83,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): t.start() with self.assertRaises(self.failureException) as fe: t.join() - self.assertTrue("integer division or modulo by zero" - in fe.exception.message) + self.assertTrue("integer division or modulo by zero" in str(fe.exception)) def testCheckedThreadWithWrongAssertionFails(self): x = 37 @@ -96,7 +95,7 @@ class TestUtilTest(test_util.TensorFlowTestCase): t.start() with self.assertRaises(self.failureException) as fe: t.join() - self.assertTrue("False is not true" in fe.exception.message) + self.assertTrue("False is not true" in str(fe.exception)) def testMultipleThreadsWithOneFailure(self): def err_func(i): diff --git a/tensorflow/python/framework/types_test.py b/tensorflow/python/framework/types_test.py deleted file mode 100644 index 90b11f3f008..00000000000 --- a/tensorflow/python/framework/types_test.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2015 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for tensorflow.python.framework.importer.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow.python.platform - -import numpy as np -import tensorflow as tf - -from tensorflow.core.framework import types_pb2 -from tensorflow.python.framework import test_util -from tensorflow.python.framework import types -from tensorflow.python.platform import googletest - - -class TypesTest(test_util.TensorFlowTestCase): - - def testAllTypesConstructible(self): - for datatype_enum in types_pb2.DataType.values(): - if datatype_enum == types_pb2.DT_INVALID: - continue - self.assertEqual( - datatype_enum, types.DType(datatype_enum).as_datatype_enum) - - def testAllTypesConvertibleToDType(self): - for datatype_enum in types_pb2.DataType.values(): - if datatype_enum == types_pb2.DT_INVALID: - continue - self.assertEqual( - datatype_enum, types.as_dtype(datatype_enum).as_datatype_enum) - - def testAllTypesConvertibleToNumpyDtype(self): - for datatype_enum in types_pb2.DataType.values(): - if datatype_enum == types_pb2.DT_INVALID: - continue - dtype = types.as_dtype(datatype_enum) - numpy_dtype = dtype.as_numpy_dtype - _ = np.empty((1, 1, 1, 1), dtype=numpy_dtype) - if dtype.base_dtype != types.bfloat16: - # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16. - self.assertEqual( - types.as_dtype(datatype_enum).base_dtype, types.as_dtype(numpy_dtype)) - - def testInvalid(self): - with self.assertRaises(TypeError): - types.DType(types_pb2.DT_INVALID) - with self.assertRaises(TypeError): - types.as_dtype(types_pb2.DT_INVALID) - - def testNumpyConversion(self): - self.assertIs(types.float32, types.as_dtype(np.float32)) - self.assertIs(types.float64, types.as_dtype(np.float64)) - self.assertIs(types.int32, types.as_dtype(np.int32)) - self.assertIs(types.int64, types.as_dtype(np.int64)) - self.assertIs(types.uint8, types.as_dtype(np.uint8)) - self.assertIs(types.int16, types.as_dtype(np.int16)) - self.assertIs(types.int8, types.as_dtype(np.int8)) - self.assertIs(types.complex64, types.as_dtype(np.complex64)) - self.assertIs(types.string, types.as_dtype(np.object)) - self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype)) - self.assertIs(types.bool, types.as_dtype(np.bool)) - with self.assertRaises(TypeError): - types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)])) - - def testStringConversion(self): - self.assertIs(types.float32, types.as_dtype("float32")) - self.assertIs(types.float64, types.as_dtype("float64")) - self.assertIs(types.int32, types.as_dtype("int32")) - self.assertIs(types.uint8, types.as_dtype("uint8")) - self.assertIs(types.int16, types.as_dtype("int16")) - self.assertIs(types.int8, types.as_dtype("int8")) - self.assertIs(types.string, types.as_dtype("string")) - self.assertIs(types.complex64, types.as_dtype("complex64")) - self.assertIs(types.int64, types.as_dtype("int64")) - self.assertIs(types.bool, types.as_dtype("bool")) - self.assertIs(types.qint8, types.as_dtype("qint8")) - self.assertIs(types.quint8, types.as_dtype("quint8")) - self.assertIs(types.qint32, types.as_dtype("qint32")) - self.assertIs(types.bfloat16, types.as_dtype("bfloat16")) - self.assertIs(types.float32_ref, types.as_dtype("float32_ref")) - self.assertIs(types.float64_ref, types.as_dtype("float64_ref")) - self.assertIs(types.int32_ref, types.as_dtype("int32_ref")) - self.assertIs(types.uint8_ref, types.as_dtype("uint8_ref")) - self.assertIs(types.int16_ref, types.as_dtype("int16_ref")) - self.assertIs(types.int8_ref, types.as_dtype("int8_ref")) - self.assertIs(types.string_ref, types.as_dtype("string_ref")) - self.assertIs(types.complex64_ref, types.as_dtype("complex64_ref")) - self.assertIs(types.int64_ref, types.as_dtype("int64_ref")) - self.assertIs(types.bool_ref, types.as_dtype("bool_ref")) - self.assertIs(types.qint8_ref, types.as_dtype("qint8_ref")) - self.assertIs(types.quint8_ref, types.as_dtype("quint8_ref")) - self.assertIs(types.qint32_ref, types.as_dtype("qint32_ref")) - self.assertIs(types.bfloat16_ref, types.as_dtype("bfloat16_ref")) - with self.assertRaises(TypeError): - types.as_dtype("not_a_type") - - def testDTypesHaveUniqueNames(self): - dtypes = [] - names = set() - for datatype_enum in types_pb2.DataType.values(): - if datatype_enum == types_pb2.DT_INVALID: - continue - dtype = types.as_dtype(datatype_enum) - dtypes.append(dtype) - names.add(dtype.name) - self.assertEqual(len(dtypes), len(names)) - - def testIsInteger(self): - self.assertEqual(types.as_dtype("int8").is_integer, True) - self.assertEqual(types.as_dtype("int16").is_integer, True) - self.assertEqual(types.as_dtype("int32").is_integer, True) - self.assertEqual(types.as_dtype("int64").is_integer, True) - self.assertEqual(types.as_dtype("uint8").is_integer, True) - self.assertEqual(types.as_dtype("complex64").is_integer, False) - self.assertEqual(types.as_dtype("float").is_integer, False) - self.assertEqual(types.as_dtype("double").is_integer, False) - self.assertEqual(types.as_dtype("string").is_integer, False) - self.assertEqual(types.as_dtype("bool").is_integer, False) - - def testIsFloating(self): - self.assertEqual(types.as_dtype("int8").is_floating, False) - self.assertEqual(types.as_dtype("int16").is_floating, False) - self.assertEqual(types.as_dtype("int32").is_floating, False) - self.assertEqual(types.as_dtype("int64").is_floating, False) - self.assertEqual(types.as_dtype("uint8").is_floating, False) - self.assertEqual(types.as_dtype("complex64").is_floating, False) - self.assertEqual(types.as_dtype("float32").is_floating, True) - self.assertEqual(types.as_dtype("float64").is_floating, True) - self.assertEqual(types.as_dtype("string").is_floating, False) - self.assertEqual(types.as_dtype("bool").is_floating, False) - - def testMinMax(self): - # make sure min/max evaluates for all data types that have min/max - for datatype_enum in types_pb2.DataType.values(): - if datatype_enum == types_pb2.DT_INVALID: - continue - dtype = types.as_dtype(datatype_enum) - numpy_dtype = dtype.as_numpy_dtype - - # ignore types for which there are no minimum/maximum (or we cannot - # compute it, such as for the q* types) - if (dtype.is_quantized or - dtype.base_dtype == types.bool or - dtype.base_dtype == types.string or - dtype.base_dtype == types.complex64): - continue - - print("%s: %s - %s" % (dtype, dtype.min, dtype.max)) - - # check some values that are known - if numpy_dtype == np.bool_: - self.assertEquals(dtype.min, 0) - self.assertEquals(dtype.max, 1) - if numpy_dtype == np.int8: - self.assertEquals(dtype.min, -128) - self.assertEquals(dtype.max, 127) - if numpy_dtype == np.int16: - self.assertEquals(dtype.min, -32768) - self.assertEquals(dtype.max, 32767) - if numpy_dtype == np.int32: - self.assertEquals(dtype.min, -2147483648) - self.assertEquals(dtype.max, 2147483647) - if numpy_dtype == np.int64: - self.assertEquals(dtype.min, -9223372036854775808) - self.assertEquals(dtype.max, 9223372036854775807) - if numpy_dtype == np.uint8: - self.assertEquals(dtype.min, 0) - self.assertEquals(dtype.max, 255) - if numpy_dtype == np.uint16: - self.assertEquals(dtype.min, 0) - self.assertEquals(dtype.max, 4294967295) - if numpy_dtype == np.uint32: - self.assertEquals(dtype.min, 0) - self.assertEquals(dtype.max, 18446744073709551615) - if numpy_dtype in (np.float16, np.float32, np.float64): - self.assertEquals(dtype.min, np.finfo(numpy_dtype).min) - self.assertEquals(dtype.max, np.finfo(numpy_dtype).max) - - def testRepr(self): - for enum, name in types._TYPE_TO_STRING.items(): - dtype = types.DType(enum) - self.assertEquals(repr(dtype), 'tf.' + name) - dtype2 = eval(repr(dtype)) - self.assertEquals(type(dtype2), types.DType) - self.assertEquals(dtype, dtype2) - - -if __name__ == "__main__": - googletest.main() diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index f6e81dd9635..f96750d4b04 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -294,7 +294,8 @@ class ConcatOpTest(tf.test.TestCase): x0 = np.random.randn(*(shape0 + (n0,) + shape1)) x1 = np.random.randn(*(shape0 + (n1,) + shape1)) correct = np.concatenate([x0, x1], axis=axis) - xs = map(tf.constant, [x0, x1]) + # TODO(irving): Make tf.concat handle map, then drop list(). + xs = list(map(tf.constant, [x0, x1])) c = tf.concat(axis, xs) self.assertAllEqual(c.eval(), correct) # Check gradients diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index fb3849e9f7d..5c8e1d8cd6e 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -24,6 +24,7 @@ import numpy as np import tensorflow as tf from tensorflow.python.ops import gen_array_ops +from tensorflow.python.util import compat class ConstantTest(tf.test.TestCase): @@ -84,21 +85,21 @@ class ConstantTest(tf.test.TestCase): self._testAll(np.empty((2, 0, 5)).astype(np.complex64)) def testString(self): - self._testCpu(np.array([str(x) for x in np.arange(-15, 15)]).reshape( - [2, 3, 5])) + self._testCpu(np.array([compat.as_bytes(str(x)) + for x in np.arange(-15, 15)]).reshape([2, 3, 5])) self._testCpu(np.empty((2, 0, 5)).astype(np.str_)) def testStringWithNulls(self): with self.test_session(): - val = tf.convert_to_tensor("\0\0\0\0").eval() + val = tf.convert_to_tensor(b"\0\0\0\0").eval() self.assertEqual(len(val), 4) - self.assertEqual(val, "\0\0\0\0") + self.assertEqual(val, b"\0\0\0\0") with self.test_session(): - val = tf.convert_to_tensor("xx\0xx").eval() + val = tf.convert_to_tensor(b"xx\0xx").eval() self.assertEqual(len(val), 5) - self.assertAllEqual(val, "xx\0xx") - nested = [["\0\0\0\0", "xx\0xx"], ["\0_\0_\0_\0", "\0"]] + self.assertAllEqual(val, b"xx\0xx") + nested = [[b"\0\0\0\0", b"xx\0xx"], [b"\0_\0_\0_\0", b"\0"]] with self.test_session(): val = tf.convert_to_tensor(nested).eval() @@ -284,21 +285,21 @@ class ZerosTest(tf.test.TestCase): self.assertEqual(d.get_shape(), [2, 3]) # Test default type for both constant size and dynamic size z = tf.zeros([2, 3]) - self.assertEquals(z.dtype, tf.float32) + self.assertEqual(z.dtype, tf.float32) self.assertEqual([2, 3], z.get_shape()) z = tf.zeros(tf.shape(d)) - self.assertEquals(z.dtype, tf.float32) + self.assertEqual(z.dtype, tf.float32) self.assertEqual([2, 3], z.get_shape()) # Test explicit type control for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.complex64, tf.int64]: z = tf.zeros([2, 3], dtype=dtype) - self.assertEquals(z.dtype, dtype) - self.assertEquals([2, 3], z.get_shape()) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) z = tf.zeros(tf.shape(d), dtype=dtype) - self.assertEquals(z.dtype, dtype) - self.assertEquals([2, 3], z.get_shape()) + self.assertEqual(z.dtype, dtype) + self.assertEqual([2, 3], z.get_shape()) class ZerosLikeTest(tf.test.TestCase): @@ -314,7 +315,7 @@ class ZerosLikeTest(tf.test.TestCase): # Constructs a tensor of zeros of the same dimensions and type as "d". z_var = tf.zeros_like(d) # Test that the type is correct - self.assertEquals(z_var.dtype, dtype) + self.assertEqual(z_var.dtype, dtype) z_value = z_var.eval() # Test that the value is correct @@ -332,7 +333,7 @@ class ZerosLikeTest(tf.test.TestCase): # Constructs a tensor of zeros of the same dimensions and type as "d". z_var = gen_array_ops._zeros_like(d) # Test that the type is correct - self.assertEquals(z_var.dtype, dtype) + self.assertEqual(z_var.dtype, dtype) z_value = z_var.eval() # Test that the value is correct @@ -369,20 +370,20 @@ class OnesTest(tf.test.TestCase): self.assertEqual(d.get_shape(), [2, 3]) # Test default type for both constant size and dynamic size z = tf.ones([2, 3]) - self.assertEquals(z.dtype, tf.float32) + self.assertEqual(z.dtype, tf.float32) self.assertEqual([2, 3], z.get_shape()) z = tf.ones(tf.shape(d)) - self.assertEquals(z.dtype, tf.float32) + self.assertEqual(z.dtype, tf.float32) self.assertEqual([2, 3], z.get_shape()) # Test explicit type control for dtype in [tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.complex64, tf.int64]: z = tf.ones([2, 3], dtype=dtype) - self.assertEquals(z.dtype, dtype) + self.assertEqual(z.dtype, dtype) self.assertEqual([2, 3], z.get_shape()) z = tf.ones(tf.shape(d), dtype=dtype) - self.assertEquals(z.dtype, dtype) + self.assertEqual(z.dtype, dtype) self.assertEqual([2, 3], z.get_shape()) @@ -399,7 +400,7 @@ class OnesLikeTest(tf.test.TestCase): # Constructs a tensor of zeros of the same dimensions and type as "d". z_var = tf.ones_like(d) # Test that the type is correct - self.assertEquals(z_var.dtype, dtype) + self.assertEqual(z_var.dtype, dtype) z_value = z_var.eval() # Test that the value is correct @@ -417,7 +418,7 @@ class OnesLikeTest(tf.test.TestCase): # Constructs a tensor of zeros of the same dimensions and type as "d". z_var = tf.ones_like(d) # Test that the type is correct - self.assertEquals(z_var.dtype, dtype) + self.assertEqual(z_var.dtype, dtype) z_value = z_var.eval() # Test that the value is correct @@ -460,7 +461,7 @@ class FillTest(tf.test.TestCase): self._compare([2, 3], np_ans[0][0], np_ans, use_gpu=False) def testFillString(self): - np_ans = np.array([["yolo"] * 3] * 2) + np_ans = np.array([[b"yolo"] * 3] * 2) with self.test_session(use_gpu=False): tf_ans = tf.fill([2, 3], np_ans[0][0], name="fill").eval() self.assertAllEqual(np_ans, tf_ans) @@ -508,7 +509,7 @@ class PlaceholderTest(tf.test.TestCase): p_identity.eval() with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Cannot feed value of shape" in e.message): + ValueError, lambda e: "Cannot feed value of shape" in str(e)): p_identity.eval(feed_dict={p: feed_array[:5, :5]}) def testPartialShape(self): @@ -520,7 +521,7 @@ class PlaceholderTest(tf.test.TestCase): feed_array) with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Cannot feed value of shape" in e.message): + ValueError, lambda e: "Cannot feed value of shape" in str(e)): p_identity.eval(feed_dict={p: feed_array[:5, :2]}) def testControlDependency(self): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 1806c67065c..c9634562549 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -28,6 +28,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gradients from tensorflow.python.pywrap_tensorflow import StatusNotOK @@ -974,6 +975,30 @@ class ControlFlowTest(tf.test.TestCase): for i in xrange(10): self.assertEqual([i], q.dequeue().eval()) + def testWhileStack_1(self): + with self.test_session(): + s = gen_data_flow_ops._stack(tf.int32, stack_name="foo") + i = tf.constant(0) + + def c(i): + return tf.less(i, 10) + def b(i): + ni = tf.add(i, 1) + ni = control_flow_ops.with_dependencies( + [gen_data_flow_ops._stack_push(s, i)], ni) + return ni + r = control_flow_ops.While(c, b, [i], parallel_iterations=1) + + x = tf.constant(0) + def c1(i, _): + return tf.greater(i, 0) + def b1(i, x): + ni = tf.sub(i, 1) + nx = x + gen_data_flow_ops._stack_pop(s, tf.int32) + return [ni, nx] + _, rx = control_flow_ops.While(c1, b1, [r, x], parallel_iterations=1) + self.assertEqual(45, rx.eval()) + def testFold_1(self): with self.test_session(): elems = tf.constant([1, 2, 3, 4, 5, 6], name="data") diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 9c90830847f..4fb2fcafbf4 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -607,7 +607,7 @@ class BinaryOpTest(tf.test.TestCase): for func in [tf.add, tf.sub, tf.mul, tf.div, _ADD, _SUB, _MUL, _TRUEDIV, _FLOORDIV]: with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Incompatible shapes" in e.message): + ValueError, lambda e: "Incompatible shapes" in str(e)): func(tf.convert_to_tensor([10.0, 20.0, 30.0]), tf.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]])) @@ -737,7 +737,7 @@ class ComparisonOpTest(tf.test.TestCase): for t in dtypes: for f in funcs: with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Incompatible shapes" in e.message): + ValueError, lambda e: "Incompatible shapes" in str(e)): f(x.astype(t), y.astype(t)) @@ -813,7 +813,7 @@ class LogicalOpTest(tf.test.TestCase): y = np.random.randint(0, 2, 6).astype(np.bool).reshape(3, 2, 1) for f in [tf.logical_and, tf.logical_or, tf.logical_xor]: with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Incompatible shapes" in e.message): + ValueError, lambda e: "Incompatible shapes" in str(e)): f(x, y) diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py index 740aca812f1..c049d9749c4 100644 --- a/tensorflow/python/kernel_tests/decode_csv_op_test.py +++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py @@ -81,7 +81,7 @@ class DecodeCSVOpTest(tf.test.TestCase): "record_defaults": [["1"]] } - expected_out = [["1.0", "ab , c", "a\nbc", 'ab"c', " abc "]] + expected_out = [[b"1.0", b"ab , c", b"a\nbc", b'ab"c', b" abc "]] self._test(args, expected_out) @@ -91,7 +91,7 @@ class DecodeCSVOpTest(tf.test.TestCase): "record_defaults": [[1.0], [1], ["aa"]] } - expected_out = [[1.0, 0.2, 3], [4, 5, 6], ["aa", "bb", "cc"]] + expected_out = [[1.0, 0.2, 3], [4, 5, 6], [b"aa", b"bb", b"cc"]] self._test(args, expected_out) @@ -101,7 +101,7 @@ class DecodeCSVOpTest(tf.test.TestCase): "record_defaults": [[1.0], [0], ["a"]] } - expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], ["a", "bcd", "a"]] + expected_out = [[1.0, 0.2, 3.0], [1, 3, 0], [b"a", b"bcd", b"a"]] self._test(args, expected_out) diff --git a/tensorflow/python/kernel_tests/edit_distance_op_test.py b/tensorflow/python/kernel_tests/edit_distance_op_test.py index 7c1a857189a..ca656cdb933 100644 --- a/tensorflow/python/kernel_tests/edit_distance_op_test.py +++ b/tensorflow/python/kernel_tests/edit_distance_op_test.py @@ -26,8 +26,8 @@ import tensorflow as tf def ConstantOf(x): x = np.asarray(x) - # Convert to int64 if it's not a string - if x.dtype.char != "S": x = np.asarray(x, dtype=np.int64) + # Convert to int64 if it's not a string or unicode + if x.dtype.char not in "SU": x = np.asarray(x, dtype=np.int64) return tf.constant(x) @@ -44,8 +44,9 @@ class EditDistanceTest(tf.test.TestCase): with self.test_session(): if expected_err_re is None: # Shape inference figures out the shape from the shape variables + # Explicit tuple() needed since zip returns an iterator in Python 3. expected_shape = [ - max(h, t) for h, t in zip(hypothesis[2], truth[2])[:-1]] + max(h, t) for h, t in tuple(zip(hypothesis[2], truth[2]))[:-1]] self.assertEqual(edit_distance.get_shape(), expected_shape) output = edit_distance.eval() self.assertAllClose(output, expected_output) diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py index 635c7e6b992..be34a9a6b29 100644 --- a/tensorflow/python/kernel_tests/embedding_ops_test.py +++ b/tensorflow/python/kernel_tests/embedding_ops_test.py @@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf from tensorflow.python.kernel_tests import gradient_checker as gc +from tensorflow.python.util import compat def _AsLong(array): @@ -145,7 +146,7 @@ def _EmbeddingResult(params, id_vals, num_shards, weight_vals=None): for ids, wts in zip(id_vals, weight_vals): val_aggr = None wt_aggr = None - if isinstance(ids, int): + if isinstance(ids, compat.integral_types): ids = [ids] wts = [wts] for i, wt_val in zip(ids, wts): diff --git a/tensorflow/python/kernel_tests/gradient_checker.py b/tensorflow/python/kernel_tests/gradient_checker.py index c816660cb35..69cc811a6ba 100644 --- a/tensorflow/python/kernel_tests/gradient_checker.py +++ b/tensorflow/python/kernel_tests/gradient_checker.py @@ -26,8 +26,8 @@ import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import gradients @@ -152,10 +152,10 @@ def _ComputeDxAndDy(x, y, y_shape): def _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_init_value=None, delta=1e-3): """Computes the theoretical and numerical jacobian.""" - t = types.as_dtype(x.dtype) - allowed_types = [types.float32, types.float64] + t = dtypes.as_dtype(x.dtype) + allowed_types = [dtypes.float32, dtypes.float64] assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name - t2 = types.as_dtype(y.dtype) + t2 = dtypes.as_dtype(y.dtype) assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name if x_init_value is not None: @@ -164,7 +164,7 @@ def _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_shape, i_shape) x_data = x_init_value else: - if t == types.float32: + if t == dtypes.float32: dtype = np.float32 else: dtype = np.float64 diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py index 51682a4f4b9..d79a8c76fe4 100644 --- a/tensorflow/python/kernel_tests/identity_op_py_test.py +++ b/tensorflow/python/kernel_tests/identity_op_py_test.py @@ -40,9 +40,10 @@ class IdentityOpTest(tf.test.TestCase): self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value) def testString(self): + source = [b"A", b"b", b"C", b"d", b"E", b"f"] with self.test_session(): - value = tf.identity(["A", "b", "C", "d", "E", "f"]).eval() - self.assertAllEqual(["A", "b", "C", "d", "E", "f"], value) + value = tf.identity(source).eval() + self.assertAllEqual(source, value) def testIdentityShape(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index 543f0cb5804..23466108154 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -26,12 +26,15 @@ import tensorflow.python.platform import tensorflow as tf +from tensorflow.python.util import compat + class IoOpsTest(tf.test.TestCase): def testReadFile(self): cases = ['', 'Some contents', 'Неки садржаји на српском'] for contents in cases: + contents = compat.as_bytes(contents) temp = tempfile.NamedTemporaryFile(prefix='ReadFileTest') open(temp.name, 'wb').write(contents) with self.test_session(): @@ -40,7 +43,8 @@ class IoOpsTest(tf.test.TestCase): self.assertEqual(read.eval(), contents) def _subset(self, files, indices): - return set([files[i].name for i in range(len(files)) if i in indices]) + return set(compat.as_bytes(files[i].name) + for i in range(len(files)) if i in indices) def testMatchingFiles(self): cases = ['ABcDEF.GH', 'ABzDEF.GH', 'ABasdfjklDEF.GH', 'AB3DEF.GH', @@ -50,7 +54,8 @@ class IoOpsTest(tf.test.TestCase): with self.test_session(): # Test exact match without wildcards. for f in files: - self.assertEqual(tf.matching_files(f.name).eval(), f.name) + self.assertEqual(tf.matching_files(f.name).eval(), + compat.as_bytes(f.name)) # We will look for files matching "ABxDEF.GH*" where "x" is some wildcard. pos = files[0].name.find(cases[0]) diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py index 7e4dd29a32a..5a551585e86 100644 --- a/tensorflow/python/kernel_tests/listdiff_op_test.py +++ b/tensorflow/python/kernel_tests/listdiff_op_test.py @@ -25,6 +25,8 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +from tensorflow.python.util import compat + _TYPES = [tf.int32, tf.int64, tf.float32, tf.float64, tf.string] @@ -33,9 +35,9 @@ class ListDiffTest(tf.test.TestCase): def _testListDiff(self, x, y, out, idx): for dtype in _TYPES: if dtype == tf.string: - x = [str(a) for a in x] - y = [str(a) for a in y] - out = [str(a) for a in out] + x = [compat.as_bytes(str(a)) for a in x] + y = [compat.as_bytes(str(a)) for a in y] + out = [compat.as_bytes(str(a)) for a in out] with self.test_session() as sess: x_tensor = tf.convert_to_tensor(x, dtype=dtype) diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/pack_op_test.py index 113e9d5c742..f9bdadb82b7 100644 --- a/tensorflow/python/kernel_tests/pack_op_test.py +++ b/tensorflow/python/kernel_tests/pack_op_test.py @@ -35,7 +35,8 @@ class PackOpTest(tf.test.TestCase): for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): data = np.random.randn(*shape) # Convert [data[0], data[1], ...] separately to tensorflow - xs = map(tf.constant, data) + # TODO(irving): Remove list() once we handle maps correctly + xs = list(map(tf.constant, data)) # Pack back into a single tensorflow tensor c = tf.pack(xs) self.assertAllEqual(c.eval(), data) @@ -47,7 +48,8 @@ class PackOpTest(tf.test.TestCase): data = np.random.randn(*shape) shapes = [shape[1:]] * shape[0] with self.test_session(use_gpu=use_gpu): - xs = map(tf.constant, data) + # TODO(irving): Remove list() once we handle maps correctly + xs = list(map(tf.constant, data)) c = tf.pack(xs) err = gradient_checker.ComputeGradientError(xs, shapes, c, shape) self.assertLess(err, 1e-6) diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py index 07c9206e1db..c6223be886a 100644 --- a/tensorflow/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/python/kernel_tests/parsing_ops_test.py @@ -57,14 +57,14 @@ def _compare_output_to_expected( tf.logging.info("Comparing key: %s", k) if isinstance(v, tf.SparseTensor): # Three outputs for SparseTensor : indices, values, shape. - tester.assertEqual([k, 3], [k, len(expected_v)]) - tester.assertAllEqual(flat_output[i], expected_v[0]) - tester.assertAllEqual(flat_output[i + 1], expected_v[1]) - tester.assertAllEqual(flat_output[i + 2], expected_v[2]) + tester.assertEqual([k, len(expected_v)], [k, 3]) + tester.assertAllEqual(expected_v[0], flat_output[i]) + tester.assertAllEqual(expected_v[1], flat_output[i + 1]) + tester.assertAllEqual(expected_v[2], flat_output[i + 2]) i += 3 else: # One output for standard Tensor. - tester.assertAllEqual(flat_output[i], expected_v) + tester.assertAllEqual(expected_v, flat_output[i]) i += 1 @@ -86,13 +86,16 @@ class ParseExampleTest(tf.test.TestCase): batch_size = ( serialized.eval().size if isinstance(serialized, tf.Tensor) else np.asarray(serialized).size) - self.assertEqual(len(dense_keys), len(dense_shapes)) - for (k, s) in zip(dense_keys, dense_shapes): - self.assertEqual(tuple(out[k].get_shape().as_list()), (batch_size,) + s) - for k in sparse_keys: - self.assertEqual(tuple(out[k].indices.get_shape().as_list()), (None, 2)) - self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,)) - self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (2,)) + if dense_shapes: + self.assertEqual(len(dense_keys), len(dense_shapes)) + for (k, s) in zip(dense_keys, dense_shapes): + self.assertEqual( + tuple(out[k].get_shape().as_list()), (batch_size,) + s) + for k in sparse_keys: + self.assertEqual( + tuple(out[k].indices.get_shape().as_list()), (None, 2)) + self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,)) + self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (2,)) # Check values result = flatten_values_tensors_or_sparse(out.values()) # flatten values @@ -110,7 +113,7 @@ class ParseExampleTest(tf.test.TestCase): dense_types = [tf.int64, tf.string, tf.float32] dense_defaults = { "a": [0, 42, 0], - "b": np.random.rand(3, 3).astype(np.str), + "b": np.random.rand(3, 3).astype(bytes), cname: np.random.rand(2).astype(np.float32), } @@ -128,7 +131,7 @@ class ParseExampleTest(tf.test.TestCase): self._test( { - "names": np.empty((0,), dtype=np.str), + "names": np.empty((0,), dtype=bytes), # empty serialized input Examples "serialized": tf.convert_to_tensor(["", ""]), "dense_defaults": dense_defaults, @@ -143,7 +146,7 @@ class ParseExampleTest(tf.test.TestCase): dense_shapes = [(1, 3), (3, 3), (2,)] dense_defaults = { "a": [0, 42, 0], - "b": np.random.rand(3, 3).astype(np.str), + "b": np.random.rand(3, 3).astype(bytes), # Feature "c" is missing, since there's gaps it will cause failure. } self._test( @@ -188,6 +191,47 @@ class ParseExampleTest(tf.test.TestCase): }, expected_err_re="Name: failing, Key: a. Number of float values") + def testDenseDefaultNoShapeShouldFail(self): + original = [ + example(features=features({ + "a": float_feature([1, 1, 3]), + })), + ] + + serialized = [m.SerializeToString() for m in original] + + self._test( + { + "serialized": tf.convert_to_tensor(serialized), + "names": ["failing"], + "dense_keys": ["a"], + "dense_types": [tf.float32], + }, + expected_err_re="Name: failing, Key: a. Number of float values") + + def testDenseDefaultNoShapeOk(self): + original = [ + example(features=features({ + "a": float_feature([1]), + })), + example(features=features({ + "a": float_feature([1]), + })) + ] + + serialized = [m.SerializeToString() for m in original] + + self._test( + { + "serialized": tf.convert_to_tensor(serialized), + "names": ["passing", "passing"], + "dense_keys": ["a"], + "dense_types": [tf.float32], + }, + { + "a": np.array([1, 1], dtype=np.float32) + }) + def testSerializedContainingSparse(self): original = [ example(features=features({ @@ -201,7 +245,7 @@ class ParseExampleTest(tf.test.TestCase): })), example(features=features({ "st_c": float_feature([1, 2, -1]), - "st_d": bytes_feature(["hi"]) + "st_d": bytes_feature([b"hi"]) })) ] @@ -214,7 +258,7 @@ class ParseExampleTest(tf.test.TestCase): expected_st_d = ( # indices, values, shape np.array([[3, 0]], dtype=np.int64), - np.array(["hi"], dtype=np.str), + np.array(["hi"], dtype=bytes), np.array([4, 1], dtype=np.int64)) # batch == 2, max_elems = 1 expected_output = { @@ -234,11 +278,11 @@ class ParseExampleTest(tf.test.TestCase): original = [ example(features=features({ "a": float_feature([1, 1]), - bname: bytes_feature(["b0_str"]), + bname: bytes_feature([b"b0_str"]), })), example(features=features({ "a": float_feature([-1, -1]), - bname: bytes_feature(["b1"]), + bname: bytes_feature([b"b1"]), })) ] @@ -248,7 +292,7 @@ class ParseExampleTest(tf.test.TestCase): expected_output = { "a": np.array([[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1), - bname: np.array(["b0_str", "b1"], dtype=np.str).reshape(2, 1, 1, 1, 1), + bname: np.array(["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1), } # No defaults, values required @@ -289,7 +333,7 @@ class ParseExampleTest(tf.test.TestCase): "a": float_feature([1, 1]), })), example(features=features({ - "b": bytes_feature(["b1"]), + "b": bytes_feature([b"b1"]), })) ] @@ -304,7 +348,7 @@ class ParseExampleTest(tf.test.TestCase): expected_output = { "a": np.array([[1, 1], [3, -3]], dtype=np.float32).reshape(2, 1, 2, 1), - "b": np.array(["tmp_str", "b1"], dtype=np.str).reshape(2, 1, 1, 1, 1), + "b": np.array(["tmp_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1), } self._test( @@ -319,7 +363,7 @@ class ParseExampleTest(tf.test.TestCase): def testSerializedContainingSparseAndDenseWithNoDefault(self): dense_defaults = { "a": [1, 2, 3], - "b": np.random.rand(3, 3).astype(np.str), + "b": np.random.rand(3, 3).astype(bytes), # Feature "c" must be provided } dense_shapes = [(1, 3), (3, 3), (2,)] @@ -396,7 +440,7 @@ class ParseSingleExampleTest(tf.test.TestCase): dense_shapes = [(1, 3), (3, 3), (2,)] dense_defaults = { "a": [1, 2, 3], - "b": np.random.rand(3, 3).astype(np.str), + "b": np.random.rand(3, 3).astype(bytes), # Feature "c" must be provided } diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 6ded5180ff1..94ea9e825f6 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -24,6 +24,8 @@ import tensorflow.python.platform import tensorflow as tf +from tensorflow.python.util import compat + class IdentityReaderTest(tf.test.TestCase): @@ -49,12 +51,12 @@ class IdentityReaderTest(tf.test.TestCase): queue.close().run() self.assertAllEqual(3, queued_length.eval()) - self._ExpectRead(sess, key, value, "A") + self._ExpectRead(sess, key, value, b"A") self.assertAllEqual(1, produced.eval()) - self._ExpectRead(sess, key, value, "B") + self._ExpectRead(sess, key, value, b"B") - self._ExpectRead(sess, key, value, "C") + self._ExpectRead(sess, key, value, b"C") self.assertAllEqual(3, produced.eval()) self.assertAllEqual(0, queued_length.eval()) @@ -74,14 +76,14 @@ class IdentityReaderTest(tf.test.TestCase): key, value = reader.read(queue) enqueue.run() - self._ExpectRead(sess, key, value, "DD") - self._ExpectRead(sess, key, value, "EE") + self._ExpectRead(sess, key, value, b"DD") + self._ExpectRead(sess, key, value, b"EE") enqueue.run() - self._ExpectRead(sess, key, value, "DD") - self._ExpectRead(sess, key, value, "EE") + self._ExpectRead(sess, key, value, b"DD") + self._ExpectRead(sess, key, value, b"EE") enqueue.run() - self._ExpectRead(sess, key, value, "DD") - self._ExpectRead(sess, key, value, "EE") + self._ExpectRead(sess, key, value, b"DD") + self._ExpectRead(sess, key, value, b"EE") queue.close().run() with self.assertRaisesOpError("is closed and has insufficient elements " "\\(requested 1, current size 0\\)"): @@ -95,26 +97,26 @@ class IdentityReaderTest(tf.test.TestCase): queue.enqueue_many([["X", "Y", "Z"]]).run() key, value = reader.read(queue) - self._ExpectRead(sess, key, value, "X") + self._ExpectRead(sess, key, value, b"X") self.assertAllEqual(1, produced.eval()) state = reader.serialize_state().eval() - self._ExpectRead(sess, key, value, "Y") - self._ExpectRead(sess, key, value, "Z") + self._ExpectRead(sess, key, value, b"Y") + self._ExpectRead(sess, key, value, b"Z") self.assertAllEqual(3, produced.eval()) queue.enqueue_many([["Y", "Z"]]).run() queue.close().run() reader.restore_state(state).run() self.assertAllEqual(1, produced.eval()) - self._ExpectRead(sess, key, value, "Y") - self._ExpectRead(sess, key, value, "Z") + self._ExpectRead(sess, key, value, b"Y") + self._ExpectRead(sess, key, value, b"Z") with self.assertRaisesOpError("is closed and has insufficient elements " "\\(requested 1, current size 0\\)"): sess.run([key, value]) self.assertAllEqual(3, produced.eval()) - self.assertEqual(str, type(state)) + self.assertEqual(bytes, type(state)) with self.assertRaises(ValueError): reader.restore_state([]) @@ -132,15 +134,15 @@ class IdentityReaderTest(tf.test.TestCase): with self.assertRaisesOpError( "Could not parse state for IdentityReader 'test_reader'"): - reader.restore_state(state + "ExtraJunk").run() + reader.restore_state(state + b"ExtraJunk").run() with self.assertRaisesOpError( "Could not parse state for IdentityReader 'test_reader'"): - reader.restore_state("PREFIX" + state).run() + reader.restore_state(b"PREFIX" + state).run() with self.assertRaisesOpError( "Could not parse state for IdentityReader 'test_reader'"): - reader.restore_state("BOGUS" + state[5:]).run() + reader.restore_state(b"BOGUS" + state[5:]).run() def testReset(self): with self.test_session() as sess: @@ -152,11 +154,11 @@ class IdentityReaderTest(tf.test.TestCase): key, value = reader.read(queue) queue.enqueue_many([["X", "Y", "Z"]]).run() - self._ExpectRead(sess, key, value, "X") + self._ExpectRead(sess, key, value, b"X") self.assertLess(0, queued_length.eval()) self.assertAllEqual(1, produced.eval()) - self._ExpectRead(sess, key, value, "Y") + self._ExpectRead(sess, key, value, b"Y") self.assertLess(0, work_completed.eval()) self.assertAllEqual(2, produced.eval()) @@ -164,10 +166,10 @@ class IdentityReaderTest(tf.test.TestCase): self.assertAllEqual(0, work_completed.eval()) self.assertAllEqual(0, produced.eval()) self.assertAllEqual(1, queued_length.eval()) - self._ExpectRead(sess, key, value, "Z") + self._ExpectRead(sess, key, value, b"Z") queue.enqueue_many([["K", "L"]]).run() - self._ExpectRead(sess, key, value, "K") + self._ExpectRead(sess, key, value, b"K") class WholeFileReaderTest(tf.test.TestCase): @@ -176,9 +178,9 @@ class WholeFileReaderTest(tf.test.TestCase): super(WholeFileReaderTest, self).setUp() self._filenames = [os.path.join(self.get_temp_dir(), "whole_file.%d.txt" % i) for i in range(3)] - self._content = ["One\na\nb\n", "Two\nC\nD", "Three x, y, z"] + self._content = [b"One\na\nb\n", b"Two\nC\nD", b"Three x, y, z"] for fn, c in zip(self._filenames, self._content): - open(fn, "w").write(c) + open(fn, "wb").write(c) def tearDown(self): super(WholeFileReaderTest, self).tearDown() @@ -187,7 +189,7 @@ class WholeFileReaderTest(tf.test.TestCase): def _ExpectRead(self, sess, key, value, index): k, v = sess.run([key, value]) - self.assertAllEqual(self._filenames[index], k) + self.assertAllEqual(compat.as_bytes(self._filenames[index]), k) self.assertAllEqual(self._content[index], v) def testOneEpoch(self): @@ -233,20 +235,20 @@ class TextLineReaderTest(tf.test.TestCase): self._num_lines = 5 def _LineText(self, f, l): - return "%d: %d" % (f, l) + return compat.as_bytes("%d: %d" % (f, l)) def _CreateFiles(self): filenames = [] for i in range(self._num_files): fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) filenames.append(fn) - f = open(fn, "w") + f = open(fn, "wb") for j in range(self._num_lines): f.write(self._LineText(i, j)) # Always include a newline after the record unless it is # at the end of the file, in which case we include it sometimes. if j + 1 != self._num_lines or i == 0: - f.write("\n") + f.write(b"\n") return filenames def testOneEpoch(self): @@ -261,7 +263,7 @@ class TextLineReaderTest(tf.test.TestCase): for i in range(self._num_files): for j in range(self._num_lines): k, v = sess.run([key, value]) - self.assertAllEqual("%s:%d" % (files[i], j + 1), k) + self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k)) self.assertAllEqual(self._LineText(i, j), v) with self.assertRaisesOpError("is closed and has insufficient elements " @@ -280,7 +282,7 @@ class TextLineReaderTest(tf.test.TestCase): for i in range(self._num_files): for j in range(self._num_lines - 1): k, v = sess.run([key, value]) - self.assertAllEqual("%s:%d" % (files[i], j + 2), k) + self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k)) self.assertAllEqual(self._LineText(i, j + 1), v) with self.assertRaisesOpError("is closed and has insufficient elements " @@ -299,18 +301,18 @@ class FixedLengthRecordReaderTest(tf.test.TestCase): self._footer_bytes = 2 def _Record(self, f, r): - return str(f * 2 + r) * self._record_bytes + return compat.as_bytes(str(f * 2 + r) * self._record_bytes) def _CreateFiles(self): filenames = [] for i in range(self._num_files): fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) filenames.append(fn) - f = open(fn, "w") - f.write("H" * self._header_bytes) + f = open(fn, "wb") + f.write(b"H" * self._header_bytes) for j in range(self._num_records): f.write(self._Record(i, j)) - f.write("F" * self._footer_bytes) + f.write(b"F" * self._footer_bytes) return filenames def testOneEpoch(self): @@ -329,7 +331,7 @@ class FixedLengthRecordReaderTest(tf.test.TestCase): for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertAllEqual("%s:%d" % (files[i], j), k) + self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k)) self.assertAllEqual(self._Record(i, j), v) with self.assertRaisesOpError("is closed and has insufficient elements " @@ -345,7 +347,7 @@ class TFRecordReaderTest(tf.test.TestCase): self._num_records = 7 def _Record(self, f, r): - return "Record %d of file %d" % (r, f) + return compat.as_bytes("Record %d of file %d" % (r, f)) def _CreateFiles(self): filenames = [] @@ -369,7 +371,7 @@ class TFRecordReaderTest(tf.test.TestCase): for i in range(self._num_files): for j in range(self._num_records): k, v = sess.run([key, value]) - self.assertTrue(k.startswith("%s:" % files[i])) + self.assertTrue(compat.as_text(k).startswith("%s:" % files[i])) self.assertAllEqual(self._Record(i, j), v) with self.assertRaisesOpError("is closed and has insufficient elements " diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index abf5f7a0d25..afb437ea3c2 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -131,13 +131,13 @@ class SumReductionTest(tf.test.TestCase): np_arr = np.arange(0, 10).reshape([2, 5]).astype(np.float32) input_tensor = tf.convert_to_tensor(np_arr) with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Invalid reduction dimension" in e.message): + ValueError, lambda e: "Invalid reduction dimension" in str(e)): tf.reduce_sum(input_tensor, [-1]) with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Invalid reduction dimension" in e.message): + ValueError, lambda e: "Invalid reduction dimension" in str(e)): tf.reduce_sum(input_tensor, [2]) with self.assertRaisesWithPredicateMatch( - ValueError, lambda e: "Invalid reduction dimension" in e.message): + ValueError, lambda e: "Invalid reduction dimension" in str(e)): tf.reduce_sum(input_tensor, [0, 2]) # Int64?? diff --git a/tensorflow/python/kernel_tests/save_restore_ops_test.py b/tensorflow/python/kernel_tests/save_restore_ops_test.py index a678f74b6a2..2f898522892 100644 --- a/tensorflow/python/kernel_tests/save_restore_ops_test.py +++ b/tensorflow/python/kernel_tests/save_restore_ops_test.py @@ -31,9 +31,9 @@ class ShardedFileOpsTest(tf.test.TestCase): target="", config=tf.ConfigProto(device_count={"CPU": 2})): self.assertEqual(gen_io_ops._sharded_filename("foo", 4, 100).eval(), - "foo-00004-of-00100") + b"foo-00004-of-00100") self.assertEqual(gen_io_ops._sharded_filespec("foo", 100).eval(), - "foo-?????-of-00100") + b"foo-?????-of-00100") if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 1023ff02c9a..b2ff0b92b43 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -245,12 +245,10 @@ class TileTest(tf.test.TestCase): "uint8": (tf.uint8, int), "int32": (tf.int32, int), "int64": (tf.int64, int), - "string": (tf.string, str) + bytes: (tf.string, bytes) } - for dtype_np, v in types_to_test.items(): + for dtype_np, (dtype_tf, cast) in types_to_test.items(): with self.test_session(): - dtype_tf = v[0] - cast = v[1] inp = np.random.rand(4, 1).astype(dtype_np) a = tf.constant([cast(x) for x in inp.ravel(order="C")], shape=[4, 1], @@ -259,7 +257,7 @@ class TileTest(tf.test.TestCase): result = tiled.eval() self.assertEqual(result.shape, (4, 4)) self.assertEqual([4, 4], tiled.get_shape()) - self.assertTrue((result == np.tile(inp, (1, 4))).all()) + self.assertAllEqual(result, np.tile(inp, (1, 4))) def testInvalidDim(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index ea3a0cd3bfb..c997ebf5475 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -247,7 +247,7 @@ class SliceTest(tf.test.TestCase): c = tf.constant(5.0) with self.assertRaisesWithPredicateMatch( TypeError, - lambda e: "'Tensor' object is not iterable" in e.message): + lambda e: "'Tensor' object is not iterable" in str(e)): for _ in c: pass diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py index 34eeabc44e6..69c71798107 100644 --- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py @@ -236,7 +236,7 @@ class SparseConcatTest(tf.test.TestCase): concat_out.indices, [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]]) self.assertAllEqual( - concat_out.values, ["a", "b", "e", "c", "d", "f", "g", "h"]) + concat_out.values, [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"]) self.assertAllEqual(concat_out.shape, [3, 8]) def testMismatchedRank(self): diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py index 3e5e059327c..8f0c60c4553 100644 --- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py +++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py @@ -46,11 +46,11 @@ class SparseMatMulTest(tf.test.TestCase): np_ans = x_mat * y_mat with self.test_session(use_gpu=False): tf_ans = tf.matmul(x, y, - transpose_a=tr_a, transpose_b=tr_b, - a_is_sparse=sp_a, - b_is_sparse=sp_b) + transpose_a=tr_a, transpose_b=tr_b, + a_is_sparse=sp_a, + b_is_sparse=sp_b) out = tf_ans.eval() - self.assertAllClose(np_ans, out) + self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4) self.assertShapeEqual(np_ans, tf_ans) def testFloatBasic(self): @@ -58,7 +58,20 @@ class SparseMatMulTest(tf.test.TestCase): y = np.arange(-1., 1.).reshape([1, 2]).astype(np.float32) self._testCpuMatmul(x, y) - # Tests testing random sized matrices. + # Tests setting one dimension to be a high value. + def testFloatLarge(self): + r1 = np.random.randint(6000, 20000) + r2 = np.random.randint(1, 10) + r3 = np.random.randint(1, 10) + for m, k, n in [(r1, r2, r3), + (r2, r1, r3), + (r2, r3, r1)]: + x = RandMatrix(m, k, False) + y = RandMatrix(k, n, False) + self._testCpuMatmul(x, y) + self._testCpuMatmul(x, y, sp_a=False, sp_b=True) + + # Tests random sized matrices. def testFloatRandom(self): for _ in range(10): for tr_a in [True, False]: @@ -78,11 +91,11 @@ class MatMulGradientTest(tf.test.TestCase): a = tf.constant(RandMatrix(3, 2, tr_a), dtype=tf.float32) b = tf.constant(RandMatrix(2, 4, tr_b), dtype=tf.float32) m = tf.matmul(a, b, - name=name, - transpose_a=tr_a, - transpose_b=tr_b, - a_is_sparse=sp_a, - b_is_sparse=sp_b) + name=name, + transpose_a=tr_a, + transpose_b=tr_b, + a_is_sparse=sp_a, + b_is_sparse=sp_b) err = (gc.ComputeGradientError(a, [2, 3] if tr_a else [3, 2], m, [3, 4]) + gc.ComputeGradientError(b, [4, 2] if tr_b else [2, 4], m, [3, 4])) print("sparse_matmul gradient err = ", err) diff --git a/tensorflow/python/kernel_tests/stack_ops_test.py b/tensorflow/python/kernel_tests/stack_ops_test.py new file mode 100644 index 00000000000..9f5ed5a101b --- /dev/null +++ b/tensorflow/python/kernel_tests/stack_ops_test.py @@ -0,0 +1,97 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for tensorflow.ops.stack_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.python.platform + +import tensorflow as tf + +from tensorflow.python.framework import errors +from tensorflow.python.ops import gen_data_flow_ops + + +class StackOpTest(tf.test.TestCase): + + def _testStackPushPop(self, use_gpu): + with self.test_session(use_gpu=use_gpu): + h = gen_data_flow_ops._stack(tf.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push(h, [[4.0, 5.0]]) + with tf.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_pop(h, tf.float32) + self.assertAllClose([[4.0, 5.0]], c1.eval()) + + def testStackPushPop(self): + self._testStackPushPop(use_gpu=False) + self._testStackPushPop(use_gpu=True) + + def _testMultiStack(self, use_gpu): + with self.test_session(use_gpu=use_gpu): + h1 = gen_data_flow_ops._stack(tf.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_push(h1, 4.0) + with tf.control_dependencies([c1]): + c1 = gen_data_flow_ops._stack_pop(h1, tf.float32) + h2 = gen_data_flow_ops._stack(tf.float32, stack_name="bar") + c2 = gen_data_flow_ops._stack_push(h2, 5.0) + with tf.control_dependencies([c2]): + c2 = gen_data_flow_ops._stack_pop(h2, tf.float32) + r = c1 + c2 + self.assertAllClose(9.0, r.eval()) + + def testMultiStack(self): + self._testMultiStack(use_gpu=False) + self._testMultiStack(use_gpu=True) + + def _testDuplicateStack(self, use_gpu): + with self.test_session(use_gpu=use_gpu): + h1 = gen_data_flow_ops._stack(tf.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_push(h1, 4.0) + h2 = gen_data_flow_ops._stack(tf.float32, stack_name="foo") + c2 = gen_data_flow_ops._stack_push(h2, 5.0) + r = c1 + c2 + with self.assertRaises(errors.AlreadyExistsError): + r.eval() + + def testDuplicateStack(self): + self._testDuplicateStack(use_gpu=False) + self._testDuplicateStack(use_gpu=True) + + def _testCloseStack(self, use_gpu): + with self.test_session(use_gpu=use_gpu) as sess: + h = gen_data_flow_ops._stack(tf.float32, stack_name="foo") + c1 = gen_data_flow_ops._stack_close(h) + sess.run(c1) + + def testCloseStack(self): + self._testCloseStack(use_gpu=False) + self._testCloseStack(use_gpu=True) + + def _testPushCloseStack(self, use_gpu): + with self.test_session(use_gpu=use_gpu) as sess: + h = gen_data_flow_ops._stack(tf.float32, stack_name="foo") + c = gen_data_flow_ops._stack_push(h, [[4.0, 5.0]]) + with tf.control_dependencies([c]): + c1 = gen_data_flow_ops._stack_close(h) + sess.run(c1) + + def testPushCloseStack(self): + self._testPushCloseStack(use_gpu=False) + self._testPushCloseStack(use_gpu=True) + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py index 169e766991e..4f7129419cb 100644 --- a/tensorflow/python/kernel_tests/summary_ops_test.py +++ b/tensorflow/python/kernel_tests/summary_ops_test.py @@ -70,7 +70,7 @@ class SummaryOpsTest(tf.test.TestCase): sum_squares: 100.0 bucket_limit: 9.93809490288 bucket_limit: 10.9319043932 - bucket_limit: 1.79769313486e+308 + bucket_limit: 1.7976931348623157e+308 bucket: 0.0 bucket: 1.0 bucket: 0.0 diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py index 1fa3690e6fb..8c5ff7bd7e6 100644 --- a/tensorflow/python/kernel_tests/transpose_op_test.py +++ b/tensorflow/python/kernel_tests/transpose_op_test.py @@ -183,7 +183,7 @@ class TransposeTest(tf.test.TestCase): with self.assertRaises(ValueError): tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [[0, 1], [2, 3]]) self._testError(np.arange(0., 2 ** 10).reshape([2] * 10), - range(10), + np.arange(10), "not implemented") with self.assertRaises(IndexError): tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3]) diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index d3e085f274c..e913fc1f501 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -144,7 +144,7 @@ class VariableStoreTest(tf.test.TestCase): with self.assertRaises(ValueError) as exc: with variable_scope.variable_scope("towerA"): va2 = variable_scope.get_variable("v", [1]) - self.assertEqual(exc.exception.message[:12], "Over-sharing") + self.assertEqual(str(exc.exception)[:12], "Over-sharing") with variable_scope.variable_scope("towerA", reuse=True): va2 = variable_scope.get_variable("v", [1]) @@ -162,17 +162,17 @@ class VariableStoreTest(tf.test.TestCase): with variable_scope.variable_scope(tower_a, reuse=True): with variable_scope.variable_scope("baz"): variable_scope.get_variable("v", [1]) - self.assertEqual(exc.exception.message[:13], "Under-sharing") + self.assertEqual(str(exc.exception)[:13], "Under-sharing") with self.assertRaises(ValueError) as exc: with variable_scope.variable_scope(tower_a, reuse=True): variable_scope.get_variable("v", [2]) # Different shape. - self.assertEqual("shape" in exc.exception.message, True) + self.assertEqual("shape" in str(exc.exception), True) with self.assertRaises(ValueError) as exc: with variable_scope.variable_scope(tower_a, reuse=True): variable_scope.get_variable("v", [1], dtype=tf.int32) - self.assertEqual("dtype" in exc.exception.message, True) + self.assertEqual("dtype" in str(exc.exception), True) if __name__ == "__main__": diff --git a/tensorflow/python/lib/core/pywrap_status_test.py b/tensorflow/python/lib/core/pywrap_status_test.py index 1b1ea29ba0f..fc0a4e5104e 100644 --- a/tensorflow/python/lib/core/pywrap_status_test.py +++ b/tensorflow/python/lib/core/pywrap_status_test.py @@ -26,21 +26,6 @@ from tensorflow.python.platform import googletest class StatusTest(googletest.TestCase): - def testDefaultOk(self): - status = pywrap_tensorflow.Status() - self.assertTrue(status.ok()) - - def testCodeAndMessage(self): - status = pywrap_tensorflow.Status(error_codes_pb2.INVALID_ARGUMENT, 'foo') - self.assertEqual(error_codes_pb2.INVALID_ARGUMENT, status.code()) - self.assertEqual('foo', status.error_message()) - - def testToString(self): - status = pywrap_tensorflow.Status() - # .ToString was remapped in the .swig file, hence will not work - # self.assertIn('OK', status.ToString()) - self.assertIn('OK', str(status)) - def testException(self): with self.assertRaises(pywrap_tensorflow.StatusNotOK) as context: pywrap_tensorflow.NotOkay() diff --git a/tensorflow/python/lib/core/status.i b/tensorflow/python/lib/core/status.i index fddbc31e249..f73b3b5b89e 100644 --- a/tensorflow/python/lib/core/status.i +++ b/tensorflow/python/lib/core/status.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // SWIG wrapper for lib::tensorflow::Status %include "tensorflow/python/platform/base.i" @@ -69,13 +84,13 @@ void RaiseStatusNotOK(const tensorflow::Status& status, swig_type_info *type) { } if (StatusNotOKError != Py_None) { - auto fullmsg_ptr = make_safe(_SwigString_FromString(fullmsg)); + auto fullmsg_ptr = make_safe(_SwigSimpleStr_FromString(fullmsg)); auto exception_ptr = make_safe(PyObject_CallFunctionObjArgs( StatusNotOKError, fullmsg_ptr.get(), NULL)); exception = exception_ptr.get(); if (exception) { auto pycode = make_safe(PyInt_FromLong(static_cast<long>(code))); - auto pymsg = make_safe(_SwigString_FromString(status.error_message())); + auto pymsg = make_safe(_SwigSimpleStr_FromString(status.error_message())); auto pystatus = make_safe(SWIG_NewPointerObj( SWIG_as_voidptr(new tensorflow::Status(status)), type, SWIG_POINTER_OWN)); PyObject_SetAttrString(exception, "code", pycode.get()); @@ -100,17 +115,9 @@ void RaiseStatusNotOK(const tensorflow::Status& status, swig_type_info *type) { %unignore tensorflow; %unignore tensorflow::lib; %unignore tensorflow::Status; -%unignore tensorflow::Status::Status; -%unignore tensorflow::Status::Status(tensorflow::error::Code, StringPiece); %unignore tensorflow::Status::~Status; -%unignore tensorflow::Status::code; -%unignore tensorflow::Status::ok; -%unignore tensorflow::Status::error_message; -%unignore tensorflow::Status::ToString; %ignore tensorflow::Status::operator=; -%rename(__str__) tensorflow::Status::ToString; - %include "tensorflow/core/public/status.h" %unignoreall diff --git a/tensorflow/python/lib/core/status_helper.i b/tensorflow/python/lib/core/status_helper.i index 2e01e79ebdd..34ffc951741 100644 --- a/tensorflow/python/lib/core/status_helper.i +++ b/tensorflow/python/lib/core/status_helper.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // SWIG test helper for lib::tensorflow::Status %include "tensorflow/python/platform/base.i" diff --git a/tensorflow/python/lib/core/strings.i b/tensorflow/python/lib/core/strings.i index 7ee912c0f7c..14a03b35379 100644 --- a/tensorflow/python/lib/core/strings.i +++ b/tensorflow/python/lib/core/strings.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Wrapper functions to provide a scripting-language-friendly interface // to our string libraries. // @@ -20,6 +35,23 @@ %{ #include "tensorflow/core/lib/core/stringpiece.h" + +// Handles str in Python 2, bytes in Python 3. +// Returns true on success, false on failure. +bool _BytesToStringPiece(PyObject* obj, tensorflow::StringPiece* result) { + if (obj == Py_None) { + result->clear(); + } else { + char* ptr; + Py_ssize_t len; + if (PyBytes_AsStringAndSize(obj, &ptr, &len) == -1) { + // Python has raised an error (likely TypeError or UnicodeEncodeError). + return false; + } + result->set(ptr, len); + } + return true; +} %} %typemap(typecheck) tensorflow::StringPiece = char *; @@ -27,29 +59,13 @@ // "tensorflow::StringPiece" arguments must be specified as a 'str' or 'bytes' object. %typemap(in) tensorflow::StringPiece { - if ($input != Py_None) { - char * buf; - Py_ssize_t len; - if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) { - // Python has raised an error (likely TypeError or UnicodeEncodeError). - SWIG_fail; - } - $1.set(buf, len); - } + if (!_BytesToStringPiece($input, &$1)) SWIG_fail; } // "const tensorflow::StringPiece&" arguments can be provided the same as // "tensorflow::StringPiece", whose typemap is defined above. %typemap(in) const tensorflow::StringPiece & (tensorflow::StringPiece temp) { - if ($input != Py_None) { - char * buf; - Py_ssize_t len; - if (PyBytes_AsStringAndSize($input, &buf, &len) == -1) { - // Python has raised an error (likely TypeError). - SWIG_fail; - } - temp.set(buf, len); - } + if (!_BytesToStringPiece($input, &temp)) SWIG_fail; $1 = &temp; } diff --git a/tensorflow/python/lib/io/py_record_reader.h b/tensorflow/python/lib/io/py_record_reader.h index c6da53aecdf..883e864bd44 100644 --- a/tensorflow/python/lib/io/py_record_reader.h +++ b/tensorflow/python/lib/io/py_record_reader.h @@ -37,7 +37,7 @@ class PyRecordReader { ~PyRecordReader(); // Attempt to get the next record at "current_offset()". If - // successful, returns true, and the record contents can be retrieve + // successful, returns true, and the record contents can be retrieved // with "this->record()". Otherwise, returns false. bool GetNext(); // Return the current record contents. Only valid after the preceding call diff --git a/tensorflow/python/lib/io/py_record_reader.i b/tensorflow/python/lib/io/py_record_reader.i index 19f911bd52d..5d35819d40b 100644 --- a/tensorflow/python/lib/io/py_record_reader.i +++ b/tensorflow/python/lib/io/py_record_reader.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + %nothread tensorflow::io::PyRecordReader::GetNext; %include "tensorflow/python/platform/base.i" diff --git a/tensorflow/python/lib/io/py_record_writer.i b/tensorflow/python/lib/io/py_record_writer.i index 20fe52c495f..6c9e13d695c 100644 --- a/tensorflow/python/lib/io/py_record_writer.i +++ b/tensorflow/python/lib/io/py_record_writer.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + %nothread tensorflow::io::PyRecordWriter::WriteRecord; %include "tensorflow/python/platform/base.i" diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py index b91c0a9056a..2d7bb46458b 100644 --- a/tensorflow/python/lib/io/tf_record.py +++ b/tensorflow/python/lib/io/tf_record.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python import pywrap_tensorflow +from tensorflow.python.util import compat def tf_record_iterator(path): @@ -34,7 +35,7 @@ def tf_record_iterator(path): Raises: IOError: If `path` cannot be opened for reading. """ - reader = pywrap_tensorflow.PyRecordReader_New(path, 0) + reader = pywrap_tensorflow.PyRecordReader_New(compat.as_bytes(path), 0) if reader is None: raise IOError("Could not open %s." % path) while reader.GetNext(): @@ -62,7 +63,7 @@ class TFRecordWriter(object): Raises: IOError: If `path` cannot be opened for writing. """ - self._writer = pywrap_tensorflow.PyRecordWriter_New(path) + self._writer = pywrap_tensorflow.PyRecordWriter_New(compat.as_bytes(path)) if self._writer is None: raise IOError("Could not write to %s." % path) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 9cccc8771fc..50f3facf2e8 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -65,10 +65,10 @@ import sys import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops @@ -460,9 +460,9 @@ def transpose(a, perm=None, name="transpose"): [3 6]] # Equivalently - tf.transpose(x perm=[0, 1]) ==> [[1 4] - [2 5] - [3 6]] + tf.transpose(x, perm=[1, 0]) ==> [[1 4] + [2 5] + [3 6]] # 'perm' is more useful for n-dimensional tensors, for n > 2 # 'x' is [[[1 2 3] @@ -502,7 +502,7 @@ def transpose(a, perm=None, name="transpose"): return ret -def zeros(shape, dtype=types.float32, name=None): +def zeros(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to zero. This operation returns a tensor of type `dtype` with shape `shape` and @@ -528,7 +528,7 @@ def zeros(shape, dtype=types.float32, name=None): else: shape = ops.convert_to_tensor(shape, name="shape") output = fill(shape, constant(0, dtype=dtype), name=name) - assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype + assert output.dtype.base_dtype == dtypes.as_dtype(dtype).base_dtype return output @@ -594,12 +594,12 @@ def ones_like(tensor, dtype=None, name=None): return ones(ones_shape, dtype=dtype, name=name) -def zeros_initializer(shape, dtype=types.float32): +def zeros_initializer(shape, dtype=dtypes.float32): """An adaptor for zeros() to match the Initializer spec.""" return zeros(shape, dtype) -def ones(shape, dtype=types.float32, name=None): +def ones(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to 1. This operation returns a tensor of type `dtype` with shape `shape` and all @@ -625,7 +625,7 @@ def ones(shape, dtype=types.float32, name=None): else: shape = ops.convert_to_tensor(shape, name="shape") output = fill(shape, constant(1, dtype=dtype), name=name) - assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype + assert output.dtype.base_dtype == dtypes.as_dtype(dtype).base_dtype return output diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py index 0a3e068523d..6b8ecc13eac 100644 --- a/tensorflow/python/ops/candidate_sampling_ops.py +++ b/tensorflow/python/ops/candidate_sampling_ops.py @@ -311,7 +311,7 @@ def all_candidate_sampler(true_classes, num_true, num_sampled, unique, def compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None): - """Compute the ids of positions in sampled_candidates matching true_classes. + """Compute the position ids in `sampled_candidates` matching `true_classes`. In Candidate Sampling, this operation facilitates virtually removing sampled classes which happen to match target classes. This is done diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index da66610e57c..a2b39d6594c 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -22,8 +22,8 @@ import collections import six +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import math_ops @@ -65,7 +65,7 @@ def clip_by_norm(t, clip_norm, name=None): """Clips tensor values to a maximum L2-norm. Given a tensor `t`, and a maximum clip value `clip_norm`, this operation - normalizes `t` so that its L2-norm is less than or equal to `clip_norm'. + normalizes `t` so that its L2-norm is less than or equal to `clip_norm`. Specifically, if the L2-norm is already less than or equal to `clip_norm`, then `t` is not modified. If the L2-norm is greater than `clip_norm`, then this operation returns a tensor of the same type and shape as `t` with its @@ -146,18 +146,18 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None): if you've already computed the global norm for `t_list`, you can specify the global norm with `use_norm`. - To perform the clipping, the values t_list[i] are set to: + To perform the clipping, the values `t_list[i]` are set to: - `t_list[i] * clip_norm / max(global_norm, clip_norm)` + t_list[i] * clip_norm / max(global_norm, clip_norm) where: - `global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))` + global_norm = sqrt(sum([l2norm(t)**2 for t in t_list])) If `clip_norm > global_norm` then the entries in `t_list` remain as they are, otherwise they're all shrunk by the global ratio. - Any of the entries of `t_list` that are of type None are ignored. + Any of the entries of `t_list` that are of type `None` are ignored. This is the correct way to perform gradient clipping (for example, see R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training @@ -219,7 +219,7 @@ def clip_by_average_norm(t, clip_norm, name=None): Given a tensor `t`, and a maximum clip value `clip_norm`, this operation normalizes `t` so that its average L2-norm is less than or equal to - `clip_norm'. Specifically, if the average L2-norm is already less than or + `clip_norm`. Specifically, if the average L2-norm is already less than or equal to `clip_norm`, then `t` is not modified. If the average L2-norm is greater than `clip_norm`, then this operation returns a tensor of the same type and shape as `t` with its values set to: @@ -244,7 +244,7 @@ def clip_by_average_norm(t, clip_norm, name=None): # Calculate L2-norm per element, clip elements by ratio of clip_norm to # L2-norm per element - n_element = math_ops.cast(array_ops.size(t), types.float32) + n_element = math_ops.cast(array_ops.size(t), dtypes.float32) l2norm_inv = math_ops.rsqrt( math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t)))) tclip = array_ops.identity( diff --git a/tensorflow/python/ops/common_shapes.py b/tensorflow/python/ops/common_shapes.py index b3decb2651f..c0615f1d3b7 100644 --- a/tensorflow/python/ops/common_shapes.py +++ b/tensorflow/python/ops/common_shapes.py @@ -130,10 +130,10 @@ def get2d_conv_output_size(input_height, input_width, filter_height, # Compute number of rows in the output, based on the padding. if input_height.value is None or filter_height.value is None: out_rows = None - elif padding_type == "VALID": + elif padding_type == b"VALID": out_rows = ((input_height.value - filter_height.value + row_stride) // row_stride) - elif padding_type == "SAME": + elif padding_type == b"SAME": out_rows = (input_height.value + row_stride - 1) // row_stride else: raise ValueError("Invalid value for padding: %r" % padding_type) @@ -141,10 +141,10 @@ def get2d_conv_output_size(input_height, input_width, filter_height, # Compute number of columns in the output, based on the padding. if input_width.value is None or filter_width.value is None: out_cols = None - elif padding_type == "VALID": + elif padding_type == b"VALID": out_cols = ((input_width.value - filter_width.value + col_stride) // col_stride) - elif padding_type == "SAME": + elif padding_type == b"SAME": out_cols = (input_width.value + col_stride - 1) // col_stride return out_rows, out_cols diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py index 9cf5e99a1e4..f2aaad37a99 100644 --- a/tensorflow/python/ops/constant_op.py +++ b/tensorflow/python/ops/constant_op.py @@ -105,10 +105,10 @@ import tensorflow.python.platform import numpy as np from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types def constant(value, dtype=None, shape=None, name="Const"): @@ -182,10 +182,10 @@ def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None): raise ValueError( "Cannot convert a partially known TensorShape to a Tensor: %s" % s) if dtype is not None: - if dtype not in (types.int32, types.int64): + if dtype not in (dtypes.int32, dtypes.int64): raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype) else: - dtype = types.int32 + dtype = dtypes.int32 if name is None: name = "shape_as_tensor" return constant(s.as_list(), dtype=dtype, name=name) @@ -197,10 +197,10 @@ def _dimension_tensor_conversion_function(d, dtype=None, name=None): if d.value is None: raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d) if dtype is not None: - if dtype not in (types.int32, types.int64): + if dtype not in (dtypes.int32, dtypes.int64): raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype) else: - dtype = types.int32 + dtype = dtypes.int32 if name is None: name = "shape_as_tensor" return constant(d.value, dtype=dtype, name=name) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index cd15d921cfb..8eb1bd79bff 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -69,9 +69,9 @@ from __future__ import print_function import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import constant_op @@ -443,7 +443,7 @@ def _GetRealValue(value): # The begin position of the slice at slice_index. slice_index = forward_ctxt.grad_context.index - b1 = array_ops.zeros(elem_rank_vec, dtype=types.int32) + b1 = array_ops.zeros(elem_rank_vec, dtype=dtypes.int32) b = array_ops.concat(0, [array_ops.expand_dims(slice_index, 0), b1]) # The slice at slice_index. @@ -1134,7 +1134,7 @@ def _AsTensorList(x, p): Returns: A list of Tensors or IndexedSlices. """ - if not isinstance(x, list) and not isinstance(x, _basetuple): + if not isinstance(x, (list, _basetuple)): x = [x] l = [] @@ -1249,7 +1249,10 @@ def group(*inputs, **kwargs): # 2-level tree. The root node is the returned NoOp node. # deps contains 1 NoOp node for each device. deps = [] - for dev in sorted(six.iterkeys(ops_on_device)): + def device_key(dev): + """A sort key that allows None to be compared to strings.""" + return "" if dev is None else dev + for dev in sorted(six.iterkeys(ops_on_device), key=device_key): deps.append(_GroupControlDeps(dev, ops_on_device[dev])) return _GroupControlDeps(None, deps, name=name) @@ -1332,7 +1335,7 @@ def fold(fn, elems, elem_shape, name=None): d0 = elem_shape[0] n = math_ops.div(s0, d0) b1 = array_ops.zeros(array_ops.expand_dims(array_ops.rank(elems) - 1, 0), - dtype=types.int32) + dtype=dtypes.int32) # Initialize the output with slice 0 b = array_ops.concat(0, [[0], b1]) o = array_ops.slice(elems, b, elem_shape) @@ -1374,7 +1377,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): Expressions: ``` - f1 = lambda: tf.onstant(17) + f1 = lambda: tf.constant(17) f2 = lambda: tf.constant(23) r = case([(tf.less(x, y), f1)], default=f2) ``` @@ -1428,7 +1431,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): if not isinstance(tup, _basetuple) or len(tup) != 2: raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") pred, fn = tup - if pred.dtype != types.bool: + if pred.dtype != dtypes.bool: raise TypeError("pred must be of type bool: %s", pred.name) if not callable(fn): raise TypeError("fn for pred %s must be callable." % pred.name) @@ -1468,7 +1471,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): # TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds)) preds_c = array_ops.concat(0, preds, name="preds_c") num_true_conditions = math_ops.reduce_sum( - math_ops.cast(preds_c, types.int32), name="num_true_conds") + math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") at_most_one_true_condition = math_ops.less( num_true_conditions, constant_op.constant(2, name="two_true_conds")) diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py index 219a503fd7f..c46c29ed57e 100644 --- a/tensorflow/python/ops/data_flow_grad.py +++ b/tensorflow/python/ops/data_flow_grad.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import data_flow_ops @@ -36,8 +36,8 @@ def _DynamicStitchGrads(op, grad): indices_grad = [None] * num_values def AsInt32(x): - return (x if op.inputs[0].dtype == types.int32 else - math_ops.cast(x, types.int32)) + return (x if op.inputs[0].dtype == dtypes.int32 else + math_ops.cast(x, dtypes.int32)) inputs = [AsInt32(op.inputs[i]) for i in xrange(num_values)] if isinstance(grad, ops.IndexedSlices): output_shape = array_ops.shape(op.outputs[0]) @@ -54,3 +54,8 @@ ops.NoGradient("QueueDequeue") ops.NoGradient("QueueDequeueMany") ops.NoGradient("QueueClose") ops.NoGradient("QueueSize") + +ops.NoGradient("Stack") +ops.NoGradient("StackPush") +ops.NoGradient("StackPop") +ops.NoGradient("StackClose") diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index c50fad3211b..261193715c1 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -125,12 +125,12 @@ class QueueBase(object): A `QueueBase` object. Raises: - TypeError: when `queues` is not a list of `QueueBase` objects, + TypeError: When `queues` is not a list of `QueueBase` objects, or when the data types of `queues` are not all the same. """ if ((not queues) or (not isinstance(queues, list)) or - (not all([isinstance(x, QueueBase) for x in queues]))): + (not all(isinstance(x, QueueBase) for x in queues))): raise TypeError("A list of queues expected") dtypes = queues[0].dtypes @@ -458,6 +458,12 @@ ops.RegisterShape("QueueEnqueue")(common_shapes.unknown_shape) ops.RegisterShape("QueueEnqueueMany")(common_shapes.unknown_shape) +ops.RegisterShape("Stack")(common_shapes.scalar_shape) +ops.RegisterShape("StackPush")(common_shapes.unknown_shape) +ops.RegisterShape("StackPop")(common_shapes.unknown_shape) +ops.RegisterShape("StackClose")(common_shapes.unknown_shape) + + @ops.RegisterShape("QueueClose") def _ScalarToVoidShape(op): """Shape function for ops that take a scalar and produce no outputs.""" diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index 9fb1cf99f28..ff39c3c7513 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -19,8 +19,8 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops @@ -43,7 +43,7 @@ def embedding_lookup(params, ids, name=None): Args: params: A list of tensors with the same shape and type. - ids: A `Tensor` with type `int32` containing the ids to be looked + ids: A `Tensor` with type `int32` or `int64` containing the ids to be looked up in `params`. name: A name for the operation (optional). @@ -69,8 +69,8 @@ def embedding_lookup(params, ids, name=None): original_indices = math_ops.range(array_ops.size(flat_ids)) # Compute flat_ids % partitions for each id ids_mod_p = flat_ids % np - if ids_mod_p.dtype != types.int32: - ids_mod_p = math_ops.cast(ids_mod_p, types.int32) + if ids_mod_p.dtype != dtypes.int32: + ids_mod_p = math_ops.cast(ids_mod_p, dtypes.int32) # Partition single list of ids based on ids % np into np separate lists plist = data_flow_ops.dynamic_partition(flat_ids, ids_mod_p, np) # Similarly, partition the original indices. @@ -178,8 +178,8 @@ def embedding_lookup_sparse(params, sp_ids, sp_weights, with ops.op_scope(params + [sp_ids], name, "embedding_lookup_sparse") as name: segment_ids = sp_ids.indices[:, 0] - if segment_ids.dtype != types.int32: - segment_ids = math_ops.cast(segment_ids, types.int32) + if segment_ids.dtype != dtypes.int32: + segment_ids = math_ops.cast(segment_ids, dtypes.int32) ids = sp_ids.values if ignore_weights: diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index 72c7bc2499c..b790d9af6c9 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Implements the graph generation for computation of gradients.""" from __future__ import absolute_import @@ -27,16 +26,17 @@ import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types # pylint: disable=unused-import from tensorflow.python.ops import array_grad from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_grad from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import image_grad from tensorflow.python.ops import logging_ops from tensorflow.python.ops import linalg_grad from tensorflow.python.ops import math_grad @@ -45,7 +45,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.platform import logging - # Warn the user if we convert a sparse representation to dense with at # least this number of elements. _LARGE_SPARSE_NUM_ELEMENTS = 100000000 @@ -69,8 +68,8 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None): """ if dtype and not dtype.is_compatible_with(value.dtype): raise ValueError( - "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" - % (dtype.name, value.dtype.name)) + "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" % + (dtype.name, value.dtype.name)) if value.dense_shape is None: raise ValueError( "Tensor conversion requested for IndexedSlices without dense_shape: %s" @@ -88,11 +87,14 @@ def _IndexedSlicesToTensor(value, dtype=None, name=None): warnings.warn( "Converting sparse IndexedSlices to a dense Tensor of unknown shape. " "This may consume a large amount of memory.") - return math_ops.unsorted_segment_sum( - value.values, value.indices, value.dense_shape[0], name=name) + return math_ops.unsorted_segment_sum(value.values, + value.indices, + value.dense_shape[0], + name=name) -ops.register_tensor_conversion_function(ops.IndexedSlices, _IndexedSlicesToTensor) +ops.register_tensor_conversion_function(ops.IndexedSlices, + _IndexedSlicesToTensor) def _MarkReachedOps(from_ops, reached_ops): @@ -236,14 +238,16 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops): y = ys[i] if grad_y is None: with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)): - grad_ys[i] = array_ops.fill(array_ops.shape(y), - constant_op.constant(1, dtype=y.dtype)) + grad_ys[i] = array_ops.fill( + array_ops.shape(y), + constant_op.constant(1, + dtype=y.dtype)) else: if grad_y.dtype != y.dtype: raise ValueError("Y and ys_grad must be of the same type, " "not y: %s, ys_grad: %s " % - (types.as_dtype(y.dtype).name, - types.as_dtype(grad_y.dtype).name)) + (dtypes.as_dtype(y.dtype).name, + dtypes.as_dtype(grad_y.dtype).name)) return grad_ys @@ -265,11 +269,10 @@ def _VerifyGeneratedGradients(grads, op): inp = op.inputs[i] if grad is not None: if not grad.dtype.is_compatible_with(inp.dtype): - raise ValueError( - "Gradient type %s generated for op %s does " - "not match input type %s" % - (types.as_dtype(grad.dtype).name, op.node_def, - types.as_dtype(inp.dtype).name)) + raise ValueError("Gradient type %s generated for op %s does " + "not match input type %s" % + (dtypes.as_dtype(grad.dtype).name, op.node_def, + dtypes.as_dtype(inp.dtype).name)) def _StopOps(from_ops, pending_count): @@ -301,7 +304,10 @@ def _StopOps(from_ops, pending_count): return stop_ops -def gradients(ys, xs, grad_ys=None, name="gradients", +def gradients(ys, + xs, + grad_ys=None, + name="gradients", colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None): @@ -319,7 +325,7 @@ def gradients(ys, xs, grad_ys=None, name="gradients", `grad_ys` is a list of tensors of the same length as `ys` that holds the initial gradients for each y in `ys`. When `grad_ys` is None, we fill in a tensor of '1's of the shape of y for each y in `ys`. A - user can provide their own initial 'grad_ys` to compute the + user can provide their own initial `grad_ys` to compute the derivatives using a different initial gradient for each y (e.g., if one wanted to weight the gradient differently for each value in each y). @@ -369,8 +375,8 @@ def gradients(ys, xs, grad_ys=None, name="gradients", # to the xs. to_ops = [t.op for t in ys] from_ops = [t.op for t in xs] - pending_count, has_control_flow = _PendingCount( - ops.get_default_graph(), to_ops, from_ops) + pending_count, has_control_flow = _PendingCount(ops.get_default_graph(), + to_ops, from_ops) # Iterate over the collected ops. # @@ -418,9 +424,9 @@ def gradients(ys, xs, grad_ys=None, name="gradients", # output, it means that the cost does not depend on output[i], # therefore dC/doutput[i] is 0. for i, out_grad in enumerate(out_grads): - if (not out_grad - and types.as_dtype(op.outputs[i].dtype).base_dtype in ( - types.float32, types.float64)): + if (not out_grad and + dtypes.as_dtype(op.outputs[i].dtype).base_dtype in + (dtypes.float32, dtypes.float64)): # Only floating-point outputs get a zero gradient. Gradient # functions should ignore the gradient for other outputs. out_grads[i] = array_ops.zeros_like(op.outputs[i]) @@ -485,7 +491,8 @@ def _GetGrad(grads, t): """Gets gradient for tensor "t".""" op = t.op op_grads = grads.get(op) - if not op_grads: return None + if not op_grads: + return None t_grad = op_grads[t.value_index] assert not isinstance(t_grad, list), ( "gradients list should have been aggregated by now.") @@ -622,8 +629,8 @@ def _AggregatedGrads(grads, op, has_control_flow, aggregation_method=None): # indices. out_grads[i] = ops.IndexedSlices( array_ops.concat(0, [x.values for x in out_grad]), - array_ops.concat(0, [x.indices for x in out_grad]), - out_grad[0].dense_shape) + array_ops.concat(0, [x.indices + for x in out_grad]), out_grad[0].dense_shape) else: out_grads[i] = [] return out_grads diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 9a87319697d..5afc8a779b2 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -24,9 +24,9 @@ import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.framework import types # pylint: disable=unused-import from tensorflow.python.ops import array_grad from tensorflow.python.ops import array_ops @@ -68,7 +68,7 @@ def _OpsBetween(graph, to_ops, from_ops): reached_ops[op._id] = True gradients._MarkReachedOps(from_ops, reached_ops) between_ops = gradients._GatherInputs(to_ops, reached_ops) - between_ops.sort(lambda x, y: y._id - x._id) + between_ops.sort(key=lambda x: -x._id) return between_ops @@ -246,13 +246,13 @@ class GradientsTest(test_util.TensorFlowTestCase): @ops.RegisterGradient("TestOp") def _TestOpGrad(op, float_grad, string_grad): """Gradient function for TestOp.""" - self.assertEquals(float_grad.dtype, types.float32) + self.assertEquals(float_grad.dtype, dtypes.float32) self.assertFalse(string_grad) return float_grad ops.RegisterShape("TestOp")(None) c = constant(1.0) - x, y = g.create_op("TestOp", [c], [types.float32, types.string]).outputs + x, y = g.create_op("TestOp", [c], [dtypes.float32, dtypes.string]).outputs z = x * 2.0 w = z * 3.0 grads = gradients.gradients(z, [c]) @@ -314,7 +314,7 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): c = constant_op.constant(np_val) c_sparse = math_ops._as_indexed_slices(c) c_sparse = ops.IndexedSlices( - c_sparse.values, math_ops.cast(c_sparse.indices, types.int64), + c_sparse.values, math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape) self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval()) c_dense = math_ops.mul(c_sparse, 1.0) @@ -322,16 +322,16 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): def testWarnings(self): # Smaller than the threshold: no warning. - c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32), - array_ops.placeholder(types.int32), + c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4])) with warnings.catch_warnings(record=True) as w: math_ops.mul(c_sparse, 1.0) self.assertEqual(0, len(w)) # Greater than or equal to the threshold: warning. - c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32), - array_ops.placeholder(types.int32), + c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100])) with warnings.catch_warnings(record=True) as w: math_ops.mul(c_sparse, 1.0) @@ -341,9 +341,9 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): in str(w[0].message)) # Unknown dense shape: warning. - c_sparse = ops.IndexedSlices(array_ops.placeholder(types.float32), - array_ops.placeholder(types.int32), - array_ops.placeholder(types.int32)) + c_sparse = ops.IndexedSlices(array_ops.placeholder(dtypes.float32), + array_ops.placeholder(dtypes.int32), + array_ops.placeholder(dtypes.int32)) with warnings.catch_warnings(record=True) as w: math_ops.mul(c_sparse, 1.0) self.assertEqual(1, len(w)) diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py new file mode 100644 index 00000000000..55f1c74f0e6 --- /dev/null +++ b/tensorflow/python/ops/image_grad.py @@ -0,0 +1,57 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Contains Gradient functions for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import gen_image_ops + + +@ops.RegisterGradient("ResizeNearestNeighbor") +def _ResizeNearestNeighborGrad(op, grad): + """The derivatives for nearest neighbor resizing. + + Args: + op: The ResizeNearestNeighbor op. + grad: The tensor representing the gradient w.r.t. the output. + + Returns: + The gradients w.r.t. the input and the output. + """ + grads = gen_image_ops.resize_nearest_neighbor_grad( + grad, op.inputs[0].get_shape()[1:3]) + return [grads, None] + + +@ops.RegisterShape("ResizeNearestNeighborGrad") +def _ResizeShape(op): + """Shape function for the resize grad ops.""" + input_shape = op.inputs[0].get_shape().with_rank(4) + size = tensor_util.ConstantValue(op.inputs[1]) + if size is not None: + height = size[0] + width = size[1] + else: + height = None + width = None + return [ + tensor_shape.TensorShape([input_shape[0], height, width, input_shape[3]]) + ] diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py new file mode 100644 index 00000000000..488e741e5b4 --- /dev/null +++ b/tensorflow/python/ops/image_grad_test.py @@ -0,0 +1,85 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Python ops defined in image_grad.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order, +# pylint: disable=unused-import +import tensorflow.python.platform +from tensorflow.python.kernel_tests import gradient_checker as gc + +import numpy as np +import tensorflow as tf +# pylint: enable=g-bad-import-order +# pylint: enable=unused-import + + +class ResizeNearestNeighborOpTest(tf.test.TestCase): + + def testShapeIsCorrectAfterOp(self): + in_shape = [1, 2, 2, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 4).reshape(in_shape).astype(np.float32) + + with self.test_session() as sess: + input_tensor = tf.constant(x, shape=in_shape) + resize_out = tf.image.resize_nearest_neighbor(input_tensor, + out_shape[1:3]) + self.assertEqual(out_shape, list(resize_out.get_shape())) + + resize_out = sess.run(resize_out) + self.assertEqual(out_shape, list(resize_out.shape)) + + def testGradFromResizeToLargerInBothDims(self): + in_shape = [1, 2, 3, 1] + out_shape = [1, 4, 6, 1] + + x = np.arange(0, 6).reshape(in_shape).astype(np.float32) + + with self.test_session(): + input_tensor = tf.constant(x, shape=in_shape) + resize_out = tf.image.resize_nearest_neighbor(input_tensor, + out_shape[1:3]) + err = gc.ComputeGradientError(input_tensor, + in_shape, + resize_out, + out_shape, + x_init_value=x) + self.assertLess(err, 1e-3) + + def testGradFromResizeToSmallerInBothDims(self): + in_shape = [1, 4, 6, 1] + out_shape = [1, 2, 3, 1] + + x = np.arange(0, 24).reshape(in_shape).astype(np.float32) + + with self.test_session(): + input_tensor = tf.constant(x, shape=in_shape) + resize_out = tf.image.resize_nearest_neighbor(input_tensor, + out_shape[1:3]) + err = gc.ComputeGradientError(input_tensor, + in_shape, + resize_out, + out_shape, + x_init_value=x) + self.assertLess(err, 1e-3) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index a95e8e9765d..7be02b220f6 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -55,10 +55,6 @@ image = tf.image.decode_jpeg(...) resized_image = tf.image.resize_bilinear(image, [299, 299]) ``` -<i>Maybe refer to the Queue examples that show how to add images to a Queue -after resizing them to a fixed size, and how to dequeue batches of resized -images from the Queue.</i> - @@resize_images @@resize_area @@ -109,11 +105,11 @@ import math import tensorflow.python.platform +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import common_shapes @@ -354,7 +350,7 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, This op cuts a rectangular part out of `image`. The top-left corner of the returned image is at `offset_height, offset_width` in `image`, and its lower-right corner is at - `offset_height + target_height, offset_width + target_width'. + `offset_height + target_height, offset_width + target_width`. Args: image: 3-D tensor with shape `[height, width, channels]` @@ -554,7 +550,7 @@ def per_image_whitening(image): height, width, depth = _ImageDimensions(image) num_pixels = height * width * depth - image = math_ops.cast(image, dtype=types.float32) + image = math_ops.cast(image, dtype=dtypes.float32) image_mean = math_ops.reduce_mean(image) variance = (math_ops.reduce_mean(math_ops.square(image)) - @@ -592,7 +588,7 @@ def random_brightness(image, max_delta, seed=None): 3-D tensor of images of shape `[height, width, channels]` Raises: - ValueError: if max_delta is negative. + ValueError: if `max_delta` is negative. """ _Check3DImage(image) @@ -664,8 +660,8 @@ def adjust_brightness(image, delta, min_value=None, max_value=None): with ops.op_scope([image, delta, min_value, max_value], None, 'adjust_brightness') as name: adjusted = math_ops.add( - math_ops.cast(image, types.float32), - math_ops.cast(delta, types.float32), + math_ops.cast(image, dtypes.float32), + math_ops.cast(delta, dtypes.float32), name=name) if image.dtype.is_integer: rounded = math_ops.round(adjusted) diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 141fb8c13ae..8c52a232cb2 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function import math -from tensorflow.python.framework import types +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import math_ops @@ -38,7 +38,7 @@ def constant_initializer(value=0.0): Returns: An initializer that generates tensors with a single value. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return constant_op.constant(value, dtype=dtype, shape=shape) return _initializer @@ -57,7 +57,7 @@ def random_uniform_initializer(minval=0.0, maxval=1.0, seed=None): Returns: An initializer that generates tensors with a uniform distribution. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed) return _initializer @@ -76,7 +76,7 @@ def random_normal_initializer(mean=0.0, stddev=1.0, seed=None): Returns: An initializer that generates tensors with a normal distribution. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed) return _initializer @@ -101,7 +101,7 @@ def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None): An initializer that generates tensors with a truncated normal distribution. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed) return _initializer @@ -132,7 +132,7 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None): Returns: An initializer that generates tensors with unit variance. """ - def _initializer(shape, dtype=types.float32): + def _initializer(shape, dtype=dtypes.float32): input_size = 1.0 # Estimating input size is not possible to do perfectly, but we try. # The estimate, obtained by multiplying all dimensions but the last one, @@ -146,7 +146,7 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None): # TODO(vrv): Unhide when we are ready to expose this publicly. -def _random_walk(shape, nonlinearity, dtype=types.float32, seed=None, +def _random_walk(shape, nonlinearity, dtype=dtypes.float32, seed=None, name="random_walk"): """Create a random tensor such that backprop neither vanishes nor explodes. @@ -200,7 +200,7 @@ class _RandomWalkInitializer(object): self._nonlinearity = nonlinearity self._seed = seed - def __call__(self, shape, dtype=types.float32): + def __call__(self, shape, dtype=dtypes.float32): """Generate a tensor used to initialize a variable.""" return random_ops._random_walk(shape, self._nonlinearity, dtype, seed=self._seed) diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 69d4ee6266a..97b491bd539 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -121,9 +121,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_io_ops # pylint: disable=wildcard-import @@ -183,7 +183,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, Returns: A tensor of type "tensor_type". """ - base_type = types.as_dtype(tensor_type).base_dtype + base_type = dtypes.as_dtype(tensor_type).base_dtype return gen_io_ops._restore_slice( file_pattern, tensor_name, shape_and_slice, base_type, preferred_shard, name=name) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 9f9e0943eac..12444916ce0 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import data_flow_ops @@ -288,7 +288,7 @@ def _MulGrad(op, grad): sx = array_ops.shape(x) sy = array_ops.shape(y) rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - if x.dtype.base_dtype == types.complex64: + if x.dtype.base_dtype == dtypes.complex64: return (array_ops.reshape(math_ops.reduce_sum(grad * math_ops.conj(y), rx), sx), array_ops.reshape(math_ops.reduce_sum(math_ops.conj(x) * grad, ry), sy)) else: @@ -412,29 +412,14 @@ def _SparseMatMulGrad(op, grad): assert t1 in is_sparse and t2 in is_sparse t1_sparse = is_sparse[t1] t2_sparse = is_sparse[t2] - if not t1_sparse and not t2_sparse: - return math_ops.matmul(t1, t2, - transpose_a=transpose_a, - transpose_b=transpose_b) - transpose_out = False - if not t1_sparse: - transpose_out = True - t1, t2 = t2, t1 - t1_sparse, t2_sparse = t2_sparse, t1_sparse - assert t1_sparse - transpose_a, transpose_b = not transpose_b, not transpose_a - if transpose_b: t2 = array_ops.transpose(t2) transpose_b = False - m = math_ops.matmul(t1, t2, - transpose_a=transpose_a, - transpose_b=transpose_b, - a_is_sparse=t1_sparse, - b_is_sparse=t2_sparse) - if transpose_out: - m = array_ops.transpose(m) - return m + return math_ops.matmul(t1, t2, + transpose_a=transpose_a, + transpose_b=transpose_b, + a_is_sparse=t1_sparse, + b_is_sparse=t2_sparse) if not t_a and not t_b: return (_SparseMatMul(grad, op.inputs[1], transpose_b=True), @@ -515,7 +500,7 @@ def _ConjGrad(_, grad): @ops.RegisterGradient("Cast") def _CastGrad(op, grad): - t = [types.float32, types.float64, types.bfloat16] + t = [dtypes.float32, dtypes.float64, dtypes.bfloat16] src_type = op.inputs[0].dtype.base_dtype dst_type = grad.dtype.base_dtype if src_type in t and dst_type in t: diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 29c2fcf4324..2555b4ba178 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -156,10 +156,10 @@ import tensorflow.python.platform import numpy as np import six.moves +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_math_ops @@ -196,7 +196,7 @@ def abs(x, name=None): """ with ops.op_scope([x], name, "Abs") as name: x = ops.convert_to_tensor(x, name="x") - if x.dtype == types.complex64: + if x.dtype == dtypes.complex64: return gen_math_ops.complex_abs(x, name=name) return gen_math_ops._abs(x, name=name) @@ -333,7 +333,7 @@ def to_float(x, name="ToFloat"): Raises: TypeError: If `x` cannot be cast to the `float32`. """ - return cast(x, types.float32, name=name) + return cast(x, dtypes.float32, name=name) def to_double(x, name="ToDouble"): @@ -349,7 +349,7 @@ def to_double(x, name="ToDouble"): Raises: TypeError: If `x` cannot be cast to the `float64`. """ - return cast(x, types.float64, name=name) + return cast(x, dtypes.float64, name=name) def to_int32(x, name="ToInt32"): @@ -365,7 +365,7 @@ def to_int32(x, name="ToInt32"): Raises: TypeError: If `x` cannot be cast to the `int32`. """ - return cast(x, types.int32, name=name) + return cast(x, dtypes.int32, name=name) def to_int64(x, name="ToInt64"): @@ -381,7 +381,7 @@ def to_int64(x, name="ToInt64"): Raises: TypeError: If `x` cannot be cast to the `int64`. """ - return cast(x, types.int64, name=name) + return cast(x, dtypes.int64, name=name) def to_bfloat16(x, name="ToBFloat16"): @@ -397,7 +397,7 @@ def to_bfloat16(x, name="ToBFloat16"): Raises: TypeError: If `x` cannot be cast to the `bfloat16`. """ - return cast(x, types.bfloat16, name=name) + return cast(x, dtypes.bfloat16, name=name) ops.Tensor._override_operator("__neg__", neg) @@ -437,14 +437,14 @@ def _OverrideBinaryOperatorHelper(func, op_name): # Conversion table for __truediv__. None entries mean no conversion required. _TRUEDIV_TABLE = { - types.uint8: types.float32, - types.int8: types.float32, - types.int16: types.float32, - types.int32: types.float64, - types.int64: types.float64, - types.float32: None, - types.float64: None, - types.complex64: None, + dtypes.uint8: dtypes.float32, + dtypes.int8: dtypes.float32, + dtypes.int16: dtypes.float32, + dtypes.int32: dtypes.float64, + dtypes.int64: dtypes.float64, + dtypes.float32: None, + dtypes.float64: None, + dtypes.complex64: None, } @@ -891,7 +891,7 @@ def matmul(a, b, with ops.op_scope([a, b], name, "MatMul") as name: a = ops.convert_to_tensor(a, name="a") b = ops.convert_to_tensor(b, name="b") - if a.dtype == types.float32 and (a_is_sparse or b_is_sparse): + if a.dtype == dtypes.float32 and (a_is_sparse or b_is_sparse): return sparse_matmul(a, b, transpose_a=transpose_a, transpose_b=transpose_b, @@ -953,14 +953,14 @@ def _as_indexed_slices_list(inputs): raise TypeError("Expected a list or tuple, not a %s" % type(inputs)) outputs = [_as_indexed_slices(i) for i in inputs] with_int32_index = [o.indices for o in outputs - if o.indices.dtype == types.int32] + if o.indices.dtype == dtypes.int32] if not with_int32_index or len(with_int32_index) == len(outputs): return outputs casted_outputs = [] for o in outputs: - if o.indices.dtype == types.int32: + if o.indices.dtype == dtypes.int32: casted_outputs.append( - ops.IndexedSlices(o.values, cast(o.indices, types.int64), + ops.IndexedSlices(o.values, cast(o.indices, dtypes.int64), o.dense_shape)) else: casted_outputs.append(o) diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 749faaf73a5..17160f909e0 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -204,9 +204,9 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import candidate_sampling_ops from tensorflow.python.ops import constant_op @@ -360,7 +360,7 @@ def zero_fraction(value, name=None): value = ops.convert_to_tensor(value, name="value") zero = constant_op.constant(0, dtype=value.dtype, name="zero") return math_ops.reduce_mean(math_ops.cast(math_ops.equal(value, zero), - types.float32)) + dtypes.float32)) def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): @@ -633,35 +633,42 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, sum to 1 per-example. Args: - weights: tensor of label embeddings with shape = [num_classes, dim]. - biases: tensor of num_classes label biases - inputs: tensor with shape = [batch_size, dim] corresponding to forward - activations of the input network - labels: int tensor with shape [batch_size, num_true] - num_sampled: number of label classes to sample per batch - num_classes: number of possible label classes in the data (e.g. vocab size) - num_true: number of target classes per example (default: 1) + weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + `[num_classes, dim]`. The (possibly-sharded) class embeddings. + biases: A `Tensor` of shape `[num_classes]`. The class biases. + inputs: A `Tensor` of shape `[batch_size, dim]`. The forward + activations of the input network. + labels: A `Tensor` of type `int64` and shape `[batch_size, + num_true]`. The target classes. Note that this format differs from + the `labels` argument of `nn.softmax_cross_entropy_with_logits`. + num_sampled: An `int`. The number of classes to randomly sample per batch. + num_classes: An `int`. The number of possible classes. + num_true: An `int`. The number of target classes per training example. sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, - `sampled_expected_count`) returned by a `*_candidate_sampler` function - to use (if None, we default to `log_uniform_candidate_sampler`) - subtract_log_q: subtract the log expected count of the labels in the sample - to get the logits of the true labels (default: True) - Turn off for Negative Sampling. - remove_accidental_hits: whether to remove "accidental hits" where a sampled - label equals the true labels (bool, default: False) + `sampled_expected_count`) returned by a `*_candidate_sampler` function. + (if None, we default to `log_uniform_candidate_sampler`) + subtract_log_q: A `bool`. whether to subtract the log expected count of + the labels in the sample to get the logits of the true labels. + Default is True. Turn off for Negative Sampling. + remove_accidental_hits: A `bool`. whether to remove "accidental hits" + where a sampled class equals one of the target classes. Default is + False. name: A name for the operation (optional). - Returns: - out_logits, out_labels: tensors with shape - `[batch_size, num_true + num_sampled]` for passing to either - `sigmoid_cross_entropy_with_logits` (NCE) - or `softmax_cross_entropy_with_logits` (sampled softmax). + out_logits, out_labels: `Tensor` objects each with shape + `[batch_size, num_true + num_sampled]`, for passing to either + `nn.sigmoid_cross_entropy_with_logits` (NCE) or + `nn.softmax_cross_entropy_with_logits` (sampled softmax). """ + if not isinstance(weights, list): + weights = [weights] + with ops.op_scope( - [weights, biases, inputs, labels], name, "compute_sampled_logits"): - if labels.dtype != types.int64: - labels = math_ops.cast(labels, types.int64) + weights + [biases, inputs, labels], name, "compute_sampled_logits"): + if labels.dtype != dtypes.int64: + labels = math_ops.cast(labels, dtypes.int64) labels_flat = array_ops.reshape(labels, [-1]) # Sample the negative labels. @@ -726,7 +733,7 @@ def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled, # This is how SparseToDense expects the indices. acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1]) acc_ids_2d_int32 = array_ops.reshape(math_ops.cast( - acc_ids, types.int32), [-1, 1]) + acc_ids, dtypes.int32), [-1, 1]) sparse_indices = array_ops.concat( 1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices") # Create sampled_logits_shape = [batch_size, num_sampled] @@ -777,17 +784,19 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, with an otherwise unused class. Args: - weights: A `Tensor` of shape [num_classes, dim]. The class embeddings. - biases: A `Tensor` of shape [num_classes]. The class biases. - inputs: A `Tensor` of shape [batch_size, dim]. The forward + weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + [num_classes, dim]. The (possibly-sharded) class embeddings. + biases: A `Tensor` of shape `[num_classes]`. The class biases. + inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. labels: A `Tensor` of type `int64` and shape `[batch_size, - num_true]`. The target classes. + num_true]`. The target classes. num_sampled: An `int`. The number of classes to randomly sample per batch. num_classes: An `int`. The number of possible classes. num_true: An `int`. The number of target classes per training example. - sampled_values: a tuple of `(sampled_candidates, true_expected_count, - sampled_expected_count)` returned by a `*_candidate_sampler` function. + sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, + `sampled_expected_count`) returned by a `*_candidate_sampler` function. (if None, we default to `log_uniform_candidate_sampler`) remove_accidental_hits: A `bool`. Whether to remove "accidental hits" where a sampled class equals one of the target classes. If set to @@ -799,7 +808,7 @@ def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, name: A name for the operation (optional). Returns: - A batch_size 1-D tensor of per-example NCE losses. + A `batch_size` 1-D tensor of per-example NCE losses. """ logits, labels = _compute_sampled_logits( weights, biases, inputs, labels, num_sampled, num_classes, @@ -838,18 +847,20 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math. Args: - weights: A `Tensor` of shape [num_classes, dim]. The class embeddings. - biases: A `Tensor` of shape [num_classes]. The class biases. - inputs: A `Tensor` of shape [batch_size, dim]. The forward + weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor` + objects whose concatenation along dimension 0 has shape + [num_classes, dim]. The (possibly-sharded) class embeddings. + biases: A `Tensor` of shape `[num_classes]`. The class biases. + inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. labels: A `Tensor` of type `int64` and shape `[batch_size, - num_true]`. The target classes. Note that this format differs from - the `labels` argument of `nn.softmax_cross_entropy_with_logits`. + num_true]`. The target classes. Note that this format differs from + the `labels` argument of `nn.softmax_cross_entropy_with_logits`. num_sampled: An `int`. The number of classes to randomly sample per batch. num_classes: An `int`. The number of possible classes. num_true: An `int`. The number of target classes per training example. - sampled_values: a tuple of `(sampled_candidates, true_expected_count, - sampled_expected_count)` returned by a `*_candidate_sampler` function. + sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, + `sampled_expected_count`) returned by a `*_candidate_sampler` function. (if None, we default to `log_uniform_candidate_sampler`) remove_accidental_hits: A `bool`. whether to remove "accidental hits" where a sampled class equals one of the target classes. Default is @@ -857,7 +868,7 @@ def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, name: A name for the operation (optional). Returns: - A batch_size 1-D tensor of per-example sampled softmax losses. + A `batch_size` 1-D tensor of per-example sampled softmax losses. """ logits, labels = _compute_sampled_logits( diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index f0709dd2a97..1eb8ef4c693 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -22,10 +22,10 @@ from __future__ import print_function import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_nn_ops # pylint: disable=wildcard-import diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index c8962f60afa..65e28978baa 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -25,8 +25,8 @@ import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.kernel_tests import gradient_checker as gc from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op @@ -50,7 +50,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): pred = [min(max(p, eps), 1 - eps) for p in pred] return [-z * log(y) - (1 - z) * log(1 - y) for y, z in zip(pred, targets)] - def _Inputs(self, x=None, y=None, dtype=types.float64, sizes=None): + def _Inputs(self, x=None, y=None, dtype=dtypes.float64, sizes=None): x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y assert len(x) == len(y) @@ -70,7 +70,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): def testLogisticOutput(self): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - logits, targets, losses = self._Inputs(dtype=types.float32) + logits, targets, losses = self._Inputs(dtype=dtypes.float32) loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) np_loss = np.array(losses).astype(np.float32) tf_loss = loss.eval() @@ -79,7 +79,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase): def testLogisticOutputMultiDim(self): for use_gpu in [True, False]: with self.test_session(use_gpu=use_gpu): - logits, targets, losses = self._Inputs(dtype=types.float32, + logits, targets, losses = self._Inputs(dtype=dtypes.float32, sizes=[2, 2, 2]) loss = nn.sigmoid_cross_entropy_with_logits(logits, targets) np_loss = np.array(losses).astype(np.float32) @@ -168,9 +168,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): f_shape = [3, 3, 2, 3] x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=types.float32) + dtype=dtypes.float32) f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=types.float32) + dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") value = output.eval() @@ -205,9 +205,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): f_shape = [3, 3, 2, 3] x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=types.float32) + dtype=dtypes.float32) f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=types.float32) + dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") value = output.eval() @@ -237,9 +237,9 @@ class DeConv2DTest(test_util.TensorFlowTestCase): f_shape = [3, 3, 2, 3] x = constant_op.constant(1.0, shape=x_shape, name="x", - dtype=types.float32) + dtype=dtypes.float32) f = constant_op.constant(1.0, shape=f_shape, name="filter", - dtype=types.float32) + dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="VALID") value = output.eval() @@ -281,8 +281,8 @@ class DeConv2DTest(test_util.TensorFlowTestCase): x_val = np.random.random_sample(x_shape).astype(np.float64) f_val = np.random.random_sample(f_shape).astype(np.float64) with self.test_session(): - x = constant_op.constant(x_val, name="x", dtype=types.float32) - f = constant_op.constant(f_val, name="f", dtype=types.float32) + x = constant_op.constant(x_val, name="x", dtype=dtypes.float32) + f = constant_op.constant(f_val, name="f", dtype=dtypes.float32) output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME") err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape) print("DeConv gradient err = %g " % err) @@ -355,7 +355,7 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) dropout = nn.dropout(t, keep_prob) final_count = 0 self.assertEqual([x_dim, y_dim], dropout.get_shape()) @@ -384,7 +384,7 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) self.assertEqual([x_dim, y_dim], dropout.get_shape()) final_count = 0 @@ -410,7 +410,7 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) dropout = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1]) self.assertEqual([x_dim, y_dim], dropout.get_shape()) for _ in xrange(0, num_iter): @@ -431,8 +431,8 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.test_session(): t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) - keep_prob_placeholder = array_ops.placeholder(types.float32) + dtype=dtypes.float32) + keep_prob_placeholder = array_ops.placeholder(dtypes.float32) dropout = nn.dropout(t, keep_prob_placeholder) final_count = 0 self.assertEqual([x_dim, y_dim], dropout.get_shape()) @@ -453,9 +453,9 @@ class DropoutTest(test_util.TensorFlowTestCase): x_dim = 40 y_dim = 30 keep_prob = 0.5 - x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32) + x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) dropout_x = nn.dropout( - x, keep_prob, noise_shape=array_ops.placeholder(types.int32)) + x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) self.assertEqual(x.get_shape(), dropout_x.get_shape()) def testInvalidKeepProb(self): @@ -463,7 +463,7 @@ class DropoutTest(test_util.TensorFlowTestCase): y_dim = 30 t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) with self.assertRaises(ValueError): nn.dropout(t, -1.0) with self.assertRaises(ValueError): @@ -471,9 +471,9 @@ class DropoutTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): nn.dropout(t, [0.0, 1.0]) with self.assertRaises(ValueError): - nn.dropout(t, array_ops.placeholder(types.float64)) + nn.dropout(t, array_ops.placeholder(dtypes.float64)) with self.assertRaises(ValueError): - nn.dropout(t, array_ops.placeholder(types.float32, shape=[2])) + nn.dropout(t, array_ops.placeholder(dtypes.float32, shape=[2])) def testShapedDropoutShapeError(self): # Runs shaped dropout and verifies an error is thrown on misshapen noise. @@ -482,7 +482,7 @@ class DropoutTest(test_util.TensorFlowTestCase): keep_prob = 0.5 t = constant_op.constant(1.0, shape=[x_dim, y_dim], - dtype=types.float32) + dtype=dtypes.float32) with self.assertRaises(ValueError): _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10]) with self.assertRaises(ValueError): @@ -641,7 +641,7 @@ class MomentsTest(test_util.TensorFlowTestCase): assert len(shape) == 4 x_numpy = np.random.normal(size=shape).astype(np.float32) - x = array_ops.placeholder(types.float32, shape=[None] * len(shape)) + x = array_ops.placeholder(dtypes.float32, shape=[None] * len(shape)) axes = [0, 1, 2] if global_norm else [0] mean, var = nn.moments(x, axes) @@ -722,6 +722,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self._num_classes = 5 self._dim = 10 self._batch_size = 3 + self._num_shards = 3 def _GenerateTestInputs(self): np.random.seed(0) @@ -729,8 +730,11 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): biases = np.random.randn(self._num_classes).astype(np.float32) hidden_acts = np.random.randn(self._batch_size, self._dim).astype( np.float32) - - return weights, biases, hidden_acts + sharded_weights = [ + weights[[row for row in range(self._num_classes) + if row % self._num_shards == shard]] + for shard in range(self._num_shards)] + return weights, biases, hidden_acts, sharded_weights def _ComputeSampledLogitsNP(self, true_w, true_b, sampled_w, sampled_b, hidden_acts, @@ -763,11 +767,14 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): subtract_log_q, remove_accidental_hits, name="sampled_loss_TF"): # Should be called from within a `with test_session():` block - weights_tf = constant_op.constant(weights) + if isinstance(weights, list): + weights_tf = [constant_op.constant(shard) for shard in weights] + else: + weights_tf = constant_op.constant(weights) biases_tf = constant_op.constant(biases) hidden_acts_tf = constant_op.constant(hidden_acts, shape=(self._batch_size, self._dim)) - labels_tf = constant_op.constant(labels, dtype=types.int64, + labels_tf = constant_op.constant(labels, dtype=dtypes.int64, shape=(self._batch_size, num_true)) pred_logits_tf, pred_labels_tf = nn._compute_sampled_logits( @@ -780,7 +787,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): def testComputeSampledLogitsShapes(self): # We just check that the shapes of the returned values are correct. - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, _ = self._GenerateTestInputs() sampled = [1, 0, 2, 3] num_sampled = len(sampled) true_exp = sampled_exp = [1., 1., 1., 1.] @@ -811,7 +818,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): def testComputeSampledLogitsValues(self): # Here we check the actual numerics. - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs() eps = 1e-3 sampled = [1, 0, 2, 3] num_sampled = len(sampled) @@ -887,6 +894,23 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self.assertAllClose(logits_np, logits_tf_val, eps) self.assertAllClose(labels_np, labels_tf_val, eps) + # Test 4: Test 1, with sharded weights + logits_np, labels_np = self._ComputeSampledLogitsNP( + true_w, true_b, sampled_w, sampled_b, hidden_acts, + num_true=num_true_test) + logits_tf, labels_tf = self._ComputeSampledLogitsTF( + sharded_weights, biases, hidden_acts, labels, num_sampled, + self._num_classes, + num_true=num_true_test, + sampled_vals=test_sampled_vals, + subtract_log_q=False, + remove_accidental_hits=False, + name="sampled_loss_test1_num_true%d" % num_true_test) + + logits_tf_val, labels_tf_val = sess.run([logits_tf, labels_tf]) + self.assertAllClose(logits_np, logits_tf_val, eps) + self.assertAllClose(labels_np, labels_tf_val, eps) + def testNCELoss(self): # A simple test to verify the numerics. @@ -898,7 +922,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): pred = np.minimum(np.maximum(pred, eps), 1 - eps) return -targets * np.log(pred) - (1. - targets) * np.log(1. - pred) - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs() labels = [0, 1, 2] true_w, true_b = weights[labels], biases[labels] sampled = [1, 0, 2, 3] @@ -932,6 +956,17 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4) + # Test with sharded weights + nce_loss_tf = nn.nce_loss( + [constant_op.constant(shard) for shard in sharded_weights], + biases_tf, inputs_tf, labels_tf, + num_sampled=1, + num_classes=self._num_classes, + num_true=1, + sampled_values=test_sampled_vals) + + self.assertAllClose(nce_loss_np, nce_loss_tf.eval(), 1e-4) + def testSampledSoftmaxLoss(self): # A simple test to verify the numerics. @@ -943,7 +978,7 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): pred = stable_exp_logits / np.sum(stable_exp_logits, 1, keepdims=True) return -np.sum(targets * np.log(pred + 1.0e-20), axis=1) - weights, biases, hidden_acts = self._GenerateTestInputs() + weights, biases, hidden_acts, sharded_weights = self._GenerateTestInputs() labels = [0, 1, 2] true_w, true_b = weights[labels], biases[labels] sampled = [1, 0, 2, 3] @@ -977,6 +1012,19 @@ class ComputeSampledLogitsTest(test_util.TensorFlowTestCase): self.assertAllClose( sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4) + # Test with sharded weights + sampled_softmax_loss_tf = nn.sampled_softmax_loss( + [constant_op.constant(shard) for shard in sharded_weights], + biases_tf, inputs_tf, labels_tf, + num_sampled=1, + num_classes=self._num_classes, + num_true=1, + sampled_values=test_sampled_vals, + remove_accidental_hits=False) + + self.assertAllClose( + sampled_softmax_loss_np, sampled_softmax_loss_tf.eval(), 1e-4) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py index b2312332f1f..8525da35f5b 100644 --- a/tensorflow/python/ops/numerics.py +++ b/tensorflow/python/ops/numerics.py @@ -19,8 +19,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -45,7 +45,7 @@ def verify_tensor_all_finite(t, msg, name=None): def add_check_numerics_ops(): - """Connect a check_numerics to every floating point tensor. + """Connect a `check_numerics` to every floating point tensor. `check_numerics` operations themselves are added for each `float` or `double` tensor in the graph. For all ops in the graph, the `check_numerics` op for @@ -62,7 +62,7 @@ def add_check_numerics_ops(): # added, and ops can only be added once its inputs are added. for op in ops.get_default_graph().get_operations(): for output in op.outputs: - if output.dtype in [types.float32, types.float64]: + if output.dtype in [dtypes.float32, dtypes.float64]: message = op.name + ":" + str(output.value_index) with ops.control_dependencies(check_op): check_op = [array_ops.check_numerics(output, message=message)] diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py index 562299fdd6e..ad0406d43dc 100644 --- a/tensorflow/python/ops/op_def_library.py +++ b/tensorflow/python/ops/op_def_library.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numbers import six from tensorflow.core.framework import attr_value_pb2 @@ -27,11 +26,12 @@ from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import tensor_pb2 from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types as types_lib from tensorflow.python.ops import constant_op from tensorflow.python.platform import logging +from tensorflow.python.util import compat def _Attr(op_def, name): @@ -55,8 +55,8 @@ def _SatisfiesTypeConstraint(dtype, attr_def): if dtype not in allowed_list: raise TypeError( "DataType %s for attr '%s' not in list of allowed values: %s" % - (types_lib.as_dtype(dtype).name, attr_def.name, - ", ".join(types_lib.as_dtype(x).name for x in allowed_list))) + (dtypes.as_dtype(dtype).name, attr_def.name, + ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) def _IsListParameter(arg): @@ -137,7 +137,7 @@ def _Restructure(l, structure): def _MakeFloat(v, arg_name): - if not isinstance(v, numbers.Real): + if not isinstance(v, compat.real_types): raise TypeError("Expected float for argument '%s' not %s." % (arg_name, repr(v))) return float(v) @@ -155,11 +155,10 @@ def _MakeInt(v, arg_name): def _MakeStr(v, arg_name): - if not isinstance(v, six.string_types): + if not isinstance(v, compat.bytes_or_text_types): raise TypeError("Expected string for argument '%s' not %s." % (arg_name, repr(v))) - # TODO(irving): Figure out what to do here from Python 3 - return str(v) # Convert unicode strings to bytes. + return compat.as_bytes(v) # Convert unicode strings to bytes. def _MakeBool(v, arg_name): @@ -171,7 +170,7 @@ def _MakeBool(v, arg_name): def _MakeType(v, attr_def): try: - v = types_lib.as_dtype(v) + v = dtypes.as_dtype(v) except TypeError: raise TypeError("Expected DataType for argument '%s' not %s." % (attr_def.name, repr(v))) @@ -381,7 +380,7 @@ class OpDefLibrary(object): try: values = ops.convert_n_to_tensor_or_indexed_slices( values, name=input_arg.name, - dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None) + dtype=dtypes.as_dtype(dtype).base_dtype if dtype else None) except (TypeError, ValueError): assert dtype is not None, "Should not fail if dtype is None" assert input_arg.number_attr, "Should be number_attr case" @@ -394,11 +393,11 @@ class OpDefLibrary(object): (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError("%s that do not match expected type %s." % - (prefix, types_lib.as_dtype(dtype).name)) + (prefix, dtypes.as_dtype(dtype).name)) elif input_arg.type_attr in attrs: raise TypeError("%s that do not match type %s inferred from " "earlier arguments." % - (prefix, types_lib.as_dtype(dtype).name)) + (prefix, dtypes.as_dtype(dtype).name)) else: raise TypeError("%s that don't all match." % prefix) @@ -423,11 +422,11 @@ class OpDefLibrary(object): (input_name, op_type_name, observed)) if input_arg.type != types_pb2.DT_INVALID: raise TypeError("%s expected type of %s." % - (prefix, types_lib.as_dtype(input_arg.type).name)) + (prefix, dtypes.as_dtype(input_arg.type).name)) else: raise TypeError( "%s type %s of argument '%s'." % - (prefix, types_lib.as_dtype(attrs[input_arg.type_attr]).name, + (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name, inferred_from[input_arg.type_attr])) types = [values.dtype] @@ -501,8 +500,8 @@ class OpDefLibrary(object): "Input '%s' of '%s' Op has type list of %s that does not " "match type list %s of argument '%s'." % (input_name, op_type_name, - ", ".join(types_lib.as_dtype(x).name for x in attr_value), - ", ".join(types_lib.as_dtype(x).name + ", ".join(dtypes.as_dtype(x).name for x in attr_value), + ", ".join(dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr]), inferred_from[input_arg.type_list_attr])) else: @@ -567,8 +566,9 @@ class OpDefLibrary(object): if attr_value.s not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, attr_value.s, - '", "'.join(attr_def.allowed_values.list.s))) + (key, op_type_name, compat.as_text(attr_value.s), + '", "'.join(map(compat.as_text, + attr_def.allowed_values.list.s)))) elif attr_def.type == "list(string)": attr_value.list.s.extend([_MakeStr(x, key) for x in value]) if attr_def.HasField("allowed_values"): @@ -576,8 +576,9 @@ class OpDefLibrary(object): if x not in attr_def.allowed_values.list.s: raise ValueError( "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." % - (key, op_type_name, x, - '", "'.join(attr_def.allowed_values.list.s))) + (key, op_type_name, compat.as_text(x), + '", "'.join(map(compat.as_text, + attr_def.allowed_values.list.s)))) elif attr_def.type == "int": attr_value.i = _MakeInt(value, key) if attr_def.has_minimum: @@ -640,7 +641,7 @@ class OpDefLibrary(object): types = [arg.type] output_structure.append(None) if arg.is_ref: - types = [types_lib.as_dtype(x).as_ref for x in types] + types = [dtypes.as_dtype(x).as_ref for x in types] output_types.extend(types) if keywords: diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py index 4a8cdbe9a48..4b9d66e8a6a 100644 --- a/tensorflow/python/ops/op_def_library_test.py +++ b/tensorflow/python/ops/op_def_library_test.py @@ -23,10 +23,10 @@ from google.protobuf import text_format from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops.op_def_library import OpDefLibrary from tensorflow.python.platform import googletest @@ -104,13 +104,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testNoRegisteredOpFails(self): with self.assertRaises(RuntimeError) as cm: self._lib.apply_op("unknown", g=self._g) - self.assertEqual(cm.exception.message, "Unrecognized Op name unknown") + self.assertEqual(str(cm.exception), "Unrecognized Op name unknown") def testAddOpValidation(self): with self.assertRaises(TypeError) as cm: self._add_op("name: 'MissingTypeAttr' " "input_arg { name: 'a' type_attr: 'T' } ") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Inconsistent OpDef for 'MissingTypeAttr', " "missing attr 'T'") @@ -119,13 +119,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'a' type_attr: 'T' } " "attr { name: 'T' type: 'int' }") self.assertEqual( - cm.exception.message, + str(cm.exception), "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int") with self.assertRaises(TypeError) as cm: self._add_op("name: 'MissingNumberAttr' " "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Inconsistent OpDef for 'MissingNumberAttr', " "missing attr 'N'") @@ -134,21 +134,21 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "attr { name: 'N' type: 'type' }") self.assertEqual( - cm.exception.message, + str(cm.exception), "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type") with self.assertRaises(TypeError) as cm: self._add_op("name: 'TwoTypesA' " "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } " "attr { name: 'T' type: 'type' }") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'TwoTypesA' must have one type field not 2") with self.assertRaises(TypeError) as cm: self._add_op("name: 'TwoTypesB' " "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'TwoTypesB' must have one type field not 2") with self.assertRaises(TypeError) as cm: @@ -157,17 +157,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "type_list_attr: 'U' } " "attr { name: 'T' type: 'type' } " "attr { name: 'U' type: 'list(type)' }") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'ThreeTypes' must have one type field not 3") with self.assertRaises(TypeError) as cm: self._add_op("name: 'NoTypes' output_arg { name: 'a' } ") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Arg 'a' of 'NoTypes' must have one type field not 0") def testSimple(self): out = self._lib.apply_op("Simple", a=3) - self.assertEquals(types.float32, out.dtype) + self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'Simple' op: 'Simple' input: 'Simple/a' """, out.op.node_def) @@ -190,39 +190,39 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): def testSimpleFailures(self): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a="Bad string") - self.assertEqual(cm.exception.message, - "Expected int32, got 'Bad string' instead.") + self.assertEqual(str(cm.exception), + "Expected int32, got 'Bad string' of type 'str' instead.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=self.Tensor(types.string)) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Simple", a=self.Tensor(dtypes.string)) + self.assertEqual(str(cm.exception), "Input 'a' of 'Simple' Op has type string " "that does not match expected type of int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=6, extra="bogus") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "apply_op() got unexpected keyword arguments: extra") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "apply_op() got unexpected keyword arguments: extra1, " "extra2") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple") - self.assertEqual(cm.exception.message, "No argument for input a") + self.assertEqual(str(cm.exception), "No argument for input a") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", wrong=7) - self.assertEqual(cm.exception.message, "No argument for input a") + self.assertEqual(str(cm.exception), "No argument for input a") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Simple", a=[self.Tensor(types.int32)]) - self.assertStartsWith( - cm.exception.message, - "Expected int32, got list containing Tensors instead.") + self._lib.apply_op("Simple", a=[self.Tensor(dtypes.int32)]) + self.assertStartsWith(str(cm.exception), + "Expected int32, got list containing Tensors of type " + "'_Message' instead.") def testReservedInput(self): self._add_op("name: 'ReservedInput' " @@ -239,34 +239,34 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("Polymorphic", a=7, name="p") - self.assertEquals(types.int32, out.dtype) + self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'Polymorphic' input: 'p/a' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("Polymorphic", a="s", name="q") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'Polymorphic' input: 'q/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'r' op: 'Polymorphic' input: 'r/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Polymorphic", a="s", T=types.string) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Polymorphic", a="s", T=dtypes.string) + self.assertEqual(str(cm.exception), "Should not specify value for inferred attr 'T'.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Polymorphic", a=[self.Tensor(types.bool)]) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Polymorphic", a=[self.Tensor(dtypes.bool)]) + self.assertEqual(str(cm.exception), "List of Tensors when single Tensor expected") def testPolymorphicOut(self): @@ -274,15 +274,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") - out = self._lib.apply_op("PolymorphicOut", T=types.int32, name="p") - self.assertEquals(types.int32, out.dtype) + out = self._lib.apply_op("PolymorphicOut", T=dtypes.int32, name="p") + self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) - out = self._lib.apply_op("PolymorphicOut", T=types.bool, name="q") - self.assertEquals(types.bool, out.dtype) + out = self._lib.apply_op("PolymorphicOut", T=dtypes.bool, name="q") + self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_BOOL } } @@ -290,12 +290,12 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("PolymorphicOut") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "No argument for attr T") with self.assertRaises(TypeError) as cm: self._lib.apply_op("PolymorphicOut", T=None) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected DataType for argument 'T' not None.") def testPolymorphicDefaultOut(self): @@ -305,15 +305,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): " default_value { type: DT_STRING } }") out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) - out = self._lib.apply_op("PolymorphicDefaultOut", T=types.bool, + out = self._lib.apply_op("PolymorphicDefaultOut", T=dtypes.bool, name="q") - self.assertEquals(types.bool, out.dtype) + self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_BOOL } } @@ -327,14 +327,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("Binary", a=8, b=9, name="b") - self.assertEquals(types.int32, out.dtype) + self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'b' op: 'Binary' input: 'b/a' input: 'b/b' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("Binary", a="left", b="right", name="c") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'c' op: 'Binary' input: 'c/a' input: 'c/b' attr { key: 'T' value { type: DT_STRING } } @@ -342,13 +342,13 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Binary", a="left", b=12) - self.assertEqual(cm.exception.message, - "Expected string, got 12 instead.") + self.assertEqual(str(cm.exception), + "Expected string, got 12 of type 'int' instead.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("Binary", a=self.Tensor(types.string), - b=self.Tensor(types.int32)) - self.assertEqual(cm.exception.message, + self._lib.apply_op("Binary", a=self.Tensor(dtypes.string), + b=self.Tensor(dtypes.int32)) + self.assertEqual(str(cm.exception), "Input 'b' of 'Binary' Op has type int32 " "that does not match type string of argument 'a'.") @@ -360,14 +360,14 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): " type: DT_STRING type: DT_BOOL } } }") out = self._lib.apply_op("Restrict", a="foo", name="g") - self.assertEquals(types.string, out.dtype) + self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'g' op: 'Restrict' input: 'g/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("Restrict", a=True, name="h") - self.assertEquals(types.bool, out.dtype) + self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'h' op: 'Restrict' input: 'h/a' attr { key: 'T' value { type: DT_BOOL } } @@ -375,7 +375,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Restrict", a=17) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: " "string, bool") @@ -404,7 +404,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeList", a=17) - self.assertStartsWith(cm.exception.message, + self.assertStartsWith(str(cm.exception), "Expected list for 'a' " "argument to 'TypeList' Op, not ") @@ -429,7 +429,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Input 'b' of 'TypeListTwice' Op has type list of " "string, int32 that does not match type list " "string, bool of argument 'a'.") @@ -439,18 +439,18 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'out' type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") - out, = self._lib.apply_op("OutTypeList", T=[types.float32], name="x") - self.assertEquals(types.float32, out.dtype) + out, = self._lib.apply_op("OutTypeList", T=[dtypes.float32], name="x") + self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'x' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_FLOAT } } } """, out.op.node_def) out1, out2 = self._lib.apply_op("OutTypeList", - T=[types.int32, types.bool], + T=[dtypes.int32, dtypes.bool], name="w") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.bool, out2.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" name: 'w' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } } @@ -460,8 +460,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): self.assertEqual([], out) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("OutTypeList", T=types.int32) - self.assertEqual(cm.exception.message, "Expected list for attr T") + self._lib.apply_op("OutTypeList", T=dtypes.int32) + self.assertEqual(str(cm.exception), "Expected list for attr T") def testTypeListRestrict(self): self._add_op("name: 'TypeListRestrict' " @@ -477,7 +477,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeListRestrict", a=[True, 12]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: string, bool") @@ -488,10 +488,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): " type: DT_STRING type: DT_BOOL } } }") out1, out2 = self._lib.apply_op("OutTypeListRestrict", - t=[types.bool, types.string], + t=[dtypes.bool, dtypes.string], name="u") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.string, out2.dtype) + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.string, out2.dtype) self.assertProtoEquals(""" name: 'u' op: 'OutTypeListRestrict' attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } } @@ -499,8 +499,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("OutTypeListRestrict", - t=[types.string, types.int32]) - self.assertEqual(cm.exception.message, + t=[dtypes.string, dtypes.int32]) + self.assertEqual(str(cm.exception), "DataType int32 for attr 't' " "not in list of allowed values: string, bool") @@ -518,22 +518,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a="bad") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'a' not 'bad'.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a=[12]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'a' not [12].") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a=None) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'a' not None.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr") - self.assertEqual(cm.exception.message, "No argument for attr a") + self.assertEqual(str(cm.exception), "No argument for attr a") def testAttrFloat(self): self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }") @@ -550,7 +550,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrFloat", a="bad") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected float for argument 'a' not 'bad'.") def testAttrBool(self): @@ -568,17 +568,17 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=0) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=1) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 1.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=[]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not [].") def testAttrBoolList(self): @@ -597,7 +597,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBoolList", a=[0]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") def testAttrMin(self): @@ -610,7 +610,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrMin", a=2) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.") def testAttrListMin(self): @@ -625,7 +625,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrListMin", a=[17]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrListMin' Op " "passed list of length 1 less than minimum 2.") @@ -641,7 +641,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrEnum", a="invalid") - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnum\' Op ' 'passed string \'invalid\' not in: ' '"apples", "oranges".') @@ -659,7 +659,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnumList\' Op ' 'passed string \'invalid\' not ' 'in: "apples", "oranges".') @@ -705,7 +705,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes. # with self.assertRaises(TypeError) as cm: # self._lib.apply_op("AttrShape", a=5) - # self.assertEqual(cm.exception.message, + # self.assertEqual(str(cm.exception), # "Don't know how to convert 5 to a TensorShapeProto for " # "argument 'a'") @@ -817,42 +817,42 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=["foo", "bar"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[string, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=[self.Tensor(types.string), - self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.string), + self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have " "types [string, string] that do not match expected type " "int32.") with self.assertRaises(ValueError) as cm: self._lib.apply_op("NIntsIn", a=[99]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'a' to 'NIntsIn' Op " "with length 1 shorter than " "minimum length 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=[38, "bar"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[int32, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NIntsIn", a=[self.Tensor(types.int32), - self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.int32), + self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op " "have types [int32, string] that do not match expected " "type int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=17) - self.assertStartsWith(cm.exception.message, + self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NIntsIn' Op, not ") @@ -885,7 +885,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) op = self._lib.apply_op("NPolymorphicIn", - a=[1, self.Tensor(types.float32, name="x")], + a=[1, self.Tensor(dtypes.float32, name="x")], name="q") self.assertProtoEquals(""" name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x' @@ -895,33 +895,33 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NPolymorphicIn", a=[99]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'a' to 'NPolymorphicIn' Op with length 1 " "shorter than minimum length 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=[38, "bar"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "All tensors passed to 'a' of 'NPolymorphicIn' " "Op must have the same type.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", - a=[38, self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + a=[38, self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, string] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", - a=["abcd", self.Tensor(types.int32)]) - self.assertEqual(cm.exception.message, + a=["abcd", self.Tensor(dtypes.int32)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [string, int32] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=17) - self.assertStartsWith(cm.exception.message, + self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NPolymorphicIn' Op, not ") @@ -951,7 +951,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: string, bool") @@ -975,7 +975,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwice' Op " "with length 1 must match " "length 3 of argument 'a'.") @@ -997,23 +997,23 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'b' to 'NInPolymorphicTwice' Op " "with length 1 " "must match length 3 of argument 'a'.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'NInPolymorphicTwice' " "Op have types [string, string] that do not match type " "int32 inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NInPolymorphicTwice", - a=[self.Tensor(types.int32)], - b=[self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + a=[self.Tensor(dtypes.int32)], + b=[self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of " "'NInPolymorphicTwice' Op have types [string] that do not " "match type int32 inferred from earlier arguments.") @@ -1046,8 +1046,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) op = self._lib.apply_op("NInTwoTypeVariables", - a=[self.Tensor(types.int32, name="q")], - b=[self.Tensor(types.string, name="r")], + a=[self.Tensor(dtypes.int32, name="q")], + b=[self.Tensor(dtypes.string, name="r")], name="p") self.assertProtoEquals(""" name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r' @@ -1058,7 +1058,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwoTypeVariables' Op " "with length 1 " "must match length 3 of argument 'a'.") @@ -1090,22 +1090,22 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Don't know how to infer type variable from empty input " "list passed to input 'a' of 'InPolymorphicTwice' Op.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op " "have types [string, string] that do not match type int32 " "inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", - a=[self.Tensor(types.int32)], - b=[self.Tensor(types.string)]) - self.assertEqual(cm.exception.message, + a=[self.Tensor(dtypes.int32)], + b=[self.Tensor(dtypes.string)]) + self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' " "Op have types [string] that do not match type int32 " "inferred from earlier arguments.") @@ -1116,31 +1116,31 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } } """, out1.op.node_def) out1, out2, out3, out4, out5 = self._lib.apply_op( "NIntsOut", N=5, name="o") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) - self.assertEquals(types.int32, out3.dtype) - self.assertEquals(types.int32, out4.dtype) - self.assertEquals(types.int32, out5.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) + self.assertEqual(dtypes.int32, out3.dtype) + self.assertEqual(dtypes.int32, out4.dtype) + self.assertEqual(dtypes.int32, out5.dtype) self.assertProtoEquals(""" name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } } """, out5.op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NIntsOut", N=1) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsOut", N=[3]) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Expected int for argument 'N' not [3].") def testNIntsOutDefault(self): @@ -1151,16 +1151,16 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): out1, out2, out3 = self._lib.apply_op( "NIntsOutDefault", N=None, name="z") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) - self.assertEquals(types.int32, out3.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) + self.assertEqual(dtypes.int32, out3.dtype) self.assertProtoEquals(""" name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } } """, out1.op.node_def) out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } } """, out2.op.node_def) @@ -1172,9 +1172,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2, - T=types.int32, name="n") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + T=dtypes.int32, name="n") + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NPolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } @@ -1182,10 +1182,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( - "NPolymorphicOut", T=types.string, N=3, name="o") - self.assertEquals(types.string, out1.dtype) - self.assertEquals(types.string, out2.dtype) - self.assertEquals(types.string, out3.dtype) + "NPolymorphicOut", T=dtypes.string, N=3, name="o") + self.assertEqual(dtypes.string, out1.dtype) + self.assertEqual(dtypes.string, out2.dtype) + self.assertEqual(dtypes.string, out3.dtype) self.assertProtoEquals(""" name: 'o' op: 'NPolymorphicOut' attr { key: 'T' value { type: DT_STRING } } @@ -1193,15 +1193,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out3.op.node_def) with self.assertRaises(ValueError) as cm: - self._lib.apply_op("NPolymorphicOut", N=1, T=types.string) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string) + self.assertEqual(str(cm.exception), "Attr 'N' of 'NPolymorphicOut' Op " "passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicOut", N=3, T=[types.string]) + self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string]) self.assertEqual( - cm.exception.message, + str(cm.exception), "Expected DataType for argument 'T' not [tf.string].") def testNPolymorphicOutDefault(self): @@ -1214,8 +1214,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): out1, out2 = self._lib.apply_op( "NPolymorphicOutDefault", N=None, T=None, name="r") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.bool, out2.dtype) + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" name: 'r' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_BOOL } } @@ -1224,9 +1224,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): out1, out2, out3 = self._lib.apply_op( "NPolymorphicOutDefault", N=3, T=None, name="s") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.bool, out2.dtype) - self.assertEquals(types.bool, out3.dtype) + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) + self.assertEqual(dtypes.bool, out3.dtype) self.assertProtoEquals(""" name: 's' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_BOOL } } @@ -1234,9 +1234,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) out1, out2 = self._lib.apply_op( - "NPolymorphicOutDefault", N=None, T=types.int32, name="t") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) + "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t") + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 't' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_INT32 } } @@ -1244,10 +1244,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( - "NPolymorphicOutDefault", N=3, T=types.int32, name="u") - self.assertEquals(types.int32, out1.dtype) - self.assertEquals(types.int32, out2.dtype) - self.assertEquals(types.int32, out3.dtype) + "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u") + self.assertEqual(dtypes.int32, out1.dtype) + self.assertEqual(dtypes.int32, out2.dtype) + self.assertEqual(dtypes.int32, out3.dtype) self.assertProtoEquals(""" name: 'u' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_INT32 } } @@ -1262,10 +1262,10 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2, out3 = self._lib.apply_op( - "NPolymorphicRestrictOut", N=3, T=types.bool, name="u") - self.assertEquals(types.bool, out1.dtype) - self.assertEquals(types.bool, out2.dtype) - self.assertEquals(types.bool, out3.dtype) + "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u") + self.assertEqual(dtypes.bool, out1.dtype) + self.assertEqual(dtypes.bool, out2.dtype) + self.assertEqual(dtypes.bool, out3.dtype) self.assertProtoEquals(""" name: 'u' op: 'NPolymorphicRestrictOut' attr { key: 'T' value { type: DT_BOOL } } @@ -1273,8 +1273,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, out1.op.node_def) with self.assertRaises(TypeError) as cm: - self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=types.int32) - self.assertEqual(cm.exception.message, + self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32) + self.assertEqual(str(cm.exception), "DataType int32 for attr 'T' " "not in list of allowed values: string, bool") @@ -1286,8 +1286,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): "output_arg { name: 'a' type_attr: 'T' is_ref: true } " "attr { name: 'T' type: 'type' } ") - out = self._lib.apply_op("RefOut", T=types.bool, name="o") - self.assertEquals(types.bool_ref, out.dtype) + out = self._lib.apply_op("RefOut", T=dtypes.bool, name="o") + self.assertEqual(dtypes.bool_ref, out.dtype) self.assertProtoEquals(""" name: 'o' op: 'RefOut' attr { key: 'T' value { type: DT_BOOL } } @@ -1300,7 +1300,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): """, op.node_def) # Can pass ref to non-ref input. - out = self._lib.apply_op("RefOut", T=types.int32, name="r") + out = self._lib.apply_op("RefOut", T=dtypes.int32, name="r") out = self._lib.apply_op("Simple", a=out, name="s") self.assertProtoEquals(""" name: 's' op: 'Simple' input: 'r' @@ -1309,7 +1309,7 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): # Can't pass non-ref to ref input. with self.assertRaises(TypeError) as cm: self._lib.apply_op("RefIn", a=2) - self.assertEqual(cm.exception.message, + self.assertEqual(str(cm.exception), "Input 'a' of 'RefIn' Op requires l-value input") def testSpecifyDevice(self): @@ -1340,9 +1340,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): a, b = self._lib.apply_op("MixedStruct", n_a=n_a) self.assertTrue(isinstance(a, list)) self.assertEqual(n_a, len(a)) - self.assertTrue(all(x.dtype == types.int32 for x in a)) + self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) self.assertTrue(isinstance(b, ops.Tensor)) - self.assertEqual(types.float32, b.dtype) + self.assertEqual(dtypes.float32, b.dtype) def testStructuredOutputMultipleLists(self): self._add_op("name: 'ComplexStruct' " @@ -1355,15 +1355,15 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase): for n_a in [0, 1, 3]: for n_b in [0, 1, 3]: for t_c in [[], - [types.int32], - [types.int32, types.float32]]: + [dtypes.int32], + [dtypes.int32, dtypes.float32]]: a, b, c = self._lib.apply_op("ComplexStruct", n_a=n_a, n_b=n_b, t_c=t_c) self.assertEqual(n_a, len(a)) - self.assertTrue(all(x.dtype == types.int32 for x in a)) + self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) self.assertEqual(n_b, len(b)) - self.assertTrue(all(x.dtype == types.int64 for x in b)) + self.assertTrue(all(x.dtype == dtypes.int64 for x in b)) self.assertEqual(t_c, [x.dtype for x in c]) @@ -1387,19 +1387,19 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): def testNoGraph(self): out = self._lib.apply_op("Simple", a=3) - self.assertEquals(out.graph, ops.get_default_graph()) + self.assertEqual(out.graph, ops.get_default_graph()) def testDefaultGraph(self): with self._g.as_default(): out = self._lib.apply_op("Simple", a=3) - self.assertEquals(out.graph, self._g) + self.assertEqual(out.graph, self._g) def testIgnoreDefaultGraphWithGraphArgument(self): default_g = ops.Graph() with default_g.as_default(): out = self._lib.apply_op("Simple", a=3, g=self._g) - self.assertEquals(ops.get_default_graph(), default_g) - self.assertEquals(out.graph, self._g) + self.assertEqual(ops.get_default_graph(), default_g) + self.assertEqual(out.graph, self._g) def testDifferentGraphFails(self): a = self._lib.apply_op("Simple", a=3, g=self._g) @@ -1407,7 +1407,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): b = self._lib.apply_op("Simple", a=4, g=other_g) with self.assertRaises(ValueError) as cm: self._lib.apply_op("Binary", a=a, b=b) - self.assertTrue("must be from the same graph" in cm.exception.message) + self.assertTrue("must be from the same graph" in str(cm.exception)) def testDifferentGraphFailsWithGraphArgument(self): other_g = ops.Graph() @@ -1416,7 +1416,7 @@ class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError) as cm: self._lib.apply_op("Binary", a=a, b=b, g=self._g) self.assertTrue( - "not from the passed-in graph" in cm.exception.message) + "not from the passed-in graph" in str(cm.exception)) if __name__ == "__main__": diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py index a6408f9f0a6..87148af2fde 100644 --- a/tensorflow/python/ops/parsing_ops.py +++ b/tensorflow/python/ops/parsing_ops.py @@ -88,12 +88,12 @@ def parse_example(serialized, ``` serialized = [ - features: - { feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] }, - features: - { feature: [] }, - features: - { feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] } + features + { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } }, + features + { feature []}, + features + { feature { key: "ft" value { float_list { value: [3.0] } } } ] ``` @@ -109,14 +109,14 @@ def parse_example(serialized, ``` [ - features: { - feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } } - feature: { key: "gps" value: { float_list: { value: [] } } } + features { + feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } } + feature { key: "gps" value { float_list { value: [] } } } }, - features: { - feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } } - feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } } - feature: { key: "gps" value: { } } + features { + feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } } + feature { key: "dank" value { int64_list { value: [ 42 ] } } } + feature { key: "gps" value { } } } ] ``` @@ -152,13 +152,13 @@ def parse_example(serialized, ``` [ - features: { - feature: { key: "age" value: { int64_list: { value: [ 0 ] } } } - feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } } + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } }, - features: { - feature: { key: "age" value: { int64_list: { value: [] } } } - feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } } + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } } ] ``` @@ -204,6 +204,8 @@ def parse_example(serialized, The keys of the dict must match the dense_keys of the feature. dense_shapes: A list of tuples with the same length as `dense_keys`. The shape of the data for each dense feature referenced by `dense_keys`. + Required for any input tensors identified by dense_keys whose shapes are + anything other than [] or [1]. name: A name for this operation (optional). Returns: @@ -297,22 +299,22 @@ def parse_single_example(serialized, # pylint: disable=invalid-name For `SparseTensor`s, the first (batch) column of the indices matrix is removed (the indices matrix is a column vector), the values vector is unchanged, and - the first (batch_size) entry of the shape vector is removed (it is now a + the first (`batch_size`) entry of the shape vector is removed (it is now a single element vector). See also `parse_example`. Args: serialized: A scalar string Tensor, a single serialized Example. - See parse_example documentation for more details. + See `parse_example` documentation for more details. names: (Optional) A scalar string Tensor, the associated name. - See parse_example documentation for more details. - sparse_keys: See parse_example documentation for more details. - sparse_types: See parse_example documentation for more details. - dense_keys: See parse_example documentation for more details. - dense_types: See parse_example documentation for more details. - dense_defaults: See parse_example documentation for more details. - dense_shapes: See parse_example documentation for more details. + See `parse_example` documentation for more details. + sparse_keys: See `parse_example` documentation for more details. + sparse_types: See `parse_example` documentation for more details. + dense_keys: See `parse_example` documentation for more details. + dense_types: See `parse_example` documentation for more details. + dense_defaults: See `parse_example` documentation for more details. + dense_shapes: See `parse_example` documentation for more details. name: A name for this operation (optional). Returns: diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py index 37b1cb115e8..cefeec54bba 100644 --- a/tensorflow/python/ops/random_ops.py +++ b/tensorflow/python/ops/random_ops.py @@ -19,10 +19,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.framework import random_seed from tensorflow.python.ops import common_shapes from tensorflow.python.ops import gen_random_ops @@ -35,14 +35,14 @@ from tensorflow.python.ops.gen_random_ops import * def _ShapeTensor(shape): """Convert to an int32 or int64 tensor, defaulting to int32 if empty.""" if isinstance(shape, (tuple, list)) and not shape: - dtype = types.int32 + dtype = dtypes.int32 else: dtype = None return ops.convert_to_tensor(shape, dtype=dtype, name="shape") # pylint: disable=protected-access -def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32, +def random_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None, name=None): """Outputs random values from a normal distribution. @@ -80,7 +80,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32, ops.NoGradient("RandomStandardNormal") -def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=types.float32, +def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=dtypes.float32, seed=None, name=None): """Outputs random values from a truncated normal distribution. @@ -123,7 +123,7 @@ ops.NoGradient("TruncatedNormal") def random_uniform(shape, minval=0.0, maxval=1.0, - dtype=types.float32, seed=None, + dtype=dtypes.float32, seed=None, name=None): """Outputs random values from a uniform distribution. diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index d1682459890..1507371b408 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -46,10 +46,10 @@ import tensorflow.python.platform import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import gen_sparse_ops @@ -178,7 +178,7 @@ def sparse_reorder(sp_input, name=None): Reordering does not affect the shape of the `SparseTensor`. - For example, if sp_input has shape `[4, 5]` and `indices` / `values`: + For example, if `sp_input` has shape `[4, 5]` and `indices` / `values`: [0, 3]: b [0, 1]: a @@ -334,8 +334,8 @@ def sparse_to_indicator(sp_input, vocab_size, name=None): rank = indices_shape[1] ids = sp_input.values - if ids.dtype != types.int64: - ids = math_ops.cast(ids, types.int64) + if ids.dtype != dtypes.int64: + ids = math_ops.cast(ids, dtypes.int64) # Slice off the last dimension of indices, then then tack on the ids indices_columns_to_preserve = array_ops.slice( @@ -451,8 +451,8 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None): default_value = ops.convert_to_tensor( default_value, dtype=sp_input.values.dtype) - num_rows = math_ops.cast(sp_input.shape[0], types.int32) - all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64) + num_rows = math_ops.cast(sp_input.shape[0], dtypes.int32) + all_row_indices = math_ops.cast(math_ops.range(num_rows), dtypes.int64) empty_row_indices, _ = array_ops.list_diff( all_row_indices, sp_input.indices[:, 0]) empty_row_indicator = gen_sparse_ops.sparse_to_dense( diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py index 046625eefa2..c6e91dcd718 100644 --- a/tensorflow/python/ops/sparse_ops_test.py +++ b/tensorflow/python/ops/sparse_ops_test.py @@ -23,9 +23,9 @@ import tensorflow.python.platform import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import constant_op from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import googletest @@ -41,9 +41,9 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), + constant_op.constant(ind, dtypes.int64), constant_op.constant(val, dtype), - constant_op.constant(shape, types.int64)) + constant_op.constant(shape, dtypes.int64)) def _SparseTensor_2x3x4(self, dtype): ind = np.array([ @@ -55,13 +55,13 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): val = np.array([1, 10, 12, 103, 111, 113, 122]) shape = np.array([2, 3, 4]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), + constant_op.constant(ind, dtypes.int64), constant_op.constant(val, dtype), - constant_op.constant(shape, types.int64)) + constant_op.constant(shape, dtypes.int64)) def testInt32(self): with self.test_session(use_gpu=False): - sp_input = self._SparseTensor_5x6(types.int32) + sp_input = self._SparseTensor_5x6(dtypes.int32) output = sparse_ops.sparse_to_indicator(sp_input, 50).eval() expected_output = np.zeros((5, 50), dtype=np.bool) @@ -73,7 +73,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): def testInt64(self): with self.test_session(use_gpu=False): - sp_input = self._SparseTensor_5x6(types.int64) + sp_input = self._SparseTensor_5x6(dtypes.int64) output = sparse_ops.sparse_to_indicator(sp_input, 50).eval() expected_output = np.zeros((5, 50), dtype=np.bool) @@ -85,7 +85,7 @@ class SparseToIndicatorTest(test_util.TensorFlowTestCase): def testHigherRank(self): with self.test_session(use_gpu=False): - sp_input = self._SparseTensor_2x3x4(types.int64) + sp_input = self._SparseTensor_2x3x4(dtypes.int64) output = sparse_ops.sparse_to_indicator(sp_input, 200).eval() expected_output = np.zeros((2, 3, 200), dtype=np.bool) @@ -107,9 +107,9 @@ class SparseRetainTest(test_util.TensorFlowTestCase): val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.int32), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) def testBasic(self): with self.test_session(use_gpu=False) as sess: @@ -153,9 +153,9 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): val = np.array([0, 10, 13, 14, 32, 33]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.int32), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) def _SparseTensor_String5x6(self): ind = np.array([ @@ -165,18 +165,18 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): val = np.array(["a", "b", "c", "d", "e", "f"]) shape = np.array([5, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.string), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.string), + constant_op.constant(shape, dtypes.int64)) def _SparseTensor_2x6(self): ind = np.array([[0, 0], [1, 0], [1, 3], [1, 4]]) val = np.array([0, 10, 13, 14]) shape = np.array([2, 6]) return ops.SparseTensor( - constant_op.constant(ind, types.int64), - constant_op.constant(val, types.int32), - constant_op.constant(shape, types.int64)) + constant_op.constant(ind, dtypes.int64), + constant_op.constant(val, dtypes.int32), + constant_op.constant(shape, dtypes.int64)) def testFillNumber(self): with self.test_session(use_gpu=False) as sess: @@ -207,7 +207,8 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase): self.assertAllEqual( output.indices, [[0, 0], [1, 0], [1, 3], [1, 4], [2, 0], [3, 2], [3, 3], [4, 0]]) - self.assertAllEqual(output.values, ["a", "b", "c", "d", "", "e", "f", ""]) + self.assertAllEqual(output.values, + [b"a", b"b", b"c", b"d", b"", b"e", b"f", b""]) self.assertAllEqual(output.shape, [5, 6]) self.assertAllEqual(empty_row_indicator_out, np.array([0, 0, 1, 0, 1]).astype(np.bool)) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8ccbad16c60..c5b490ce4ce 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -22,9 +22,9 @@ from __future__ import print_function import contextlib import six +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import init_ops from tensorflow.python.ops import variables from tensorflow.python.platform import logging @@ -45,7 +45,7 @@ class _VariableStore(object): """Create a variable store.""" self._vars = {} # A dictionary of the stored TensorFlow variables. - def get_variable(self, name, shape=None, dtype=types.float32, + def get_variable(self, name, shape=None, dtype=dtypes.float32, initializer=None, reuse=None, trainable=True, collections=None): """Gets an existing variable with these parameters or create a new one. @@ -82,7 +82,7 @@ class _VariableStore(object): or when violating reuse during variable creation. """ should_check = reuse is not None - dtype = types.as_dtype(dtype) + dtype = dtypes.as_dtype(dtype) shape = tensor_shape.as_shape(shape) if name in self._vars: # Here we handle the case when returning an existing variable. @@ -158,7 +158,7 @@ class _VariableScope(object): """Set initializer for this scope.""" self._initializer = initializer - def get_variable(self, var_store, name, shape=None, dtype=types.float32, + def get_variable(self, var_store, name, shape=None, dtype=dtypes.float32, initializer=None, trainable=True, collections=None): """Gets an existing variable with this name or create a new one.""" if initializer is None and self._initializer: @@ -194,7 +194,7 @@ def _get_default_variable_store(): return store -def get_variable(name, shape=None, dtype=types.float32, initializer=None, +def get_variable(name, shape=None, dtype=dtypes.float32, initializer=None, trainable=True, collections=None): """Gets an existing variable with these parameters or create a new one. diff --git a/tensorflow/python/platform/base.i b/tensorflow/python/platform/base.i index 22792669fa9..f08ac17c5d1 100644 --- a/tensorflow/python/platform/base.i +++ b/tensorflow/python/platform/base.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Helper macros and typemaps for use in Tensorflow swig files. // %{ @@ -9,11 +24,13 @@ template<class T> bool _PyObjAs(PyObject *pystr, T* cstr) { T::undefined; // You need to define specialization _PyObjAs<T> + return NULL; } template<class T> PyObject *_PyObjFrom(const T& c) { T::undefined; // You need to define specialization _PyObjFrom<T> + return NULL; } #ifdef HAS_GLOBAL_STRING @@ -45,9 +62,19 @@ return PyBytes_FromStringAndSize(c.data(), c.size()); } - PyObject* _SwigString_FromString(const string& s) { + PyObject* _SwigBytes_FromString(const string& s) { return PyBytes_FromStringAndSize(s.data(), s.size()); } + + // The string must be both ASCII and Unicode compatible, so this routine + // should be used only for error messages and the like. + PyObject* _SwigSimpleStr_FromString(const string& s) { +#if PY_MAJOR_VERSION < 3 + return PyString_FromStringAndSize(s.data(), s.size()); +#else + return PyUnicode_FromStringAndSize(s.data(), s.size()); +#endif + } %} %typemap(in) string { @@ -118,7 +145,7 @@ std::vector<type>* OUTPUT (std::vector<type> temp), } %enddef -_LIST_OUTPUT_TYPEMAP(string, _SwigString_FromString); +_LIST_OUTPUT_TYPEMAP(string, _SwigBytes_FromString); _LIST_OUTPUT_TYPEMAP(unsigned long long, PyLong_FromUnsignedLongLong); %typemap(in) uint64 { diff --git a/tensorflow/python/platform/default/_gfile.py b/tensorflow/python/platform/default/_gfile.py index fbc4aef409f..4ee28ca0123 100644 --- a/tensorflow/python/platform/default/_gfile.py +++ b/tensorflow/python/platform/default/_gfile.py @@ -24,37 +24,14 @@ import functools import glob as _glob import os import shutil +import six import threading -class FileError(IOError): - """An error occurred while reading or writing a file.""" - - -class GOSError(OSError): - """An error occurred while finding a file or in handling pathnames.""" - - -class _GFileBase(object): +class _GFileBase(six.Iterator): """Base I/O wrapper class. Similar semantics to Python's file object.""" # pylint: disable=protected-access - def _error_wrapper(fn): - """Decorator wrapping GFileBase class method errors.""" - @functools.wraps(fn) # Preserve methods' __doc__ - def wrap(self, *args, **kwargs): - try: - return fn(self, *args, **kwargs) - except ValueError as e: - # Sometimes a ValueError is raised, e.g., a read() on a closed file. - raise FileError(errno.EIO, e.message, self._name) - except IOError as e: - e.filename = self._name - raise FileError(e) - except OSError as e: - raise GOSError(e) - return wrap - def _synchronized(fn): """Synchronizes file I/O for methods in GFileBase.""" @functools.wraps(fn) @@ -69,7 +46,6 @@ class _GFileBase(object): return sync # pylint: enable=protected-access - @_error_wrapper def __init__(self, name, mode, locker): """Create the GFileBase object with the given filename, mode, and locker. @@ -92,7 +68,6 @@ class _GFileBase(object): """Make GFileBase usable with "with" statement.""" self.close() - @_error_wrapper @_synchronized def __del__(self): # __del__ is sometimes called before initialization, in which @@ -100,20 +75,17 @@ class _GFileBase(object): # before trying to close the file handle. if hasattr(self, '_fp'): self._fp.close() - @_error_wrapper @_synchronized def flush(self): """Flush the underlying file handle.""" return self._fp.flush() @property - @_error_wrapper @_synchronized def closed(self): """Returns "True" if the file handle is closed. Otherwise False.""" return self._fp.closed - @_error_wrapper @_synchronized def write(self, data): """Write data to the underlying file handle. @@ -123,13 +95,11 @@ class _GFileBase(object): """ self._fp.write(data) - @_error_wrapper @_synchronized def writelines(self, seq): """Write a sequence of strings to the underlying file handle.""" self._fp.writelines(seq) - @_error_wrapper @_synchronized def tell(self): """Return the location from the underlying file handle. @@ -139,7 +109,6 @@ class _GFileBase(object): """ return self._fp.tell() - @_error_wrapper @_synchronized def seek(self, offset, whence=0): """Seek to offset (conditioned on whence) in the underlying file handle. @@ -150,7 +119,6 @@ class _GFileBase(object): """ self._fp.seek(offset, whence) - @_error_wrapper @_synchronized def truncate(self, new_size=None): """Truncate the underlying file handle to new_size. @@ -161,7 +129,6 @@ class _GFileBase(object): """ self._fp.truncate(new_size) - @_error_wrapper @_synchronized def readline(self, max_length=-1): """Read a single line (up to max_length) from the underlying file handle. @@ -174,7 +141,6 @@ class _GFileBase(object): """ return self._fp.readline(max_length) - @_error_wrapper @_synchronized def readlines(self, sizehint=None): """Read lines from the underlying file handle. @@ -195,8 +161,7 @@ class _GFileBase(object): return self # Not synchronized - @_error_wrapper - def next(self): + def __next__(self): """Enable line iteration on the underlying handle (not synchronized). Returns: @@ -208,7 +173,6 @@ class _GFileBase(object): """ return next(self._fp) - @_error_wrapper @_synchronized def Size(self): # pylint: disable=invalid-name """Get byte size of the file from the underlying file handle.""" @@ -220,7 +184,6 @@ class _GFileBase(object): self.seek(cur) return size - @_error_wrapper @_synchronized def read(self, n=-1): """Read n bytes from the underlying file handle. @@ -233,7 +196,6 @@ class _GFileBase(object): """ return self._fp.read(n) - @_error_wrapper @_synchronized def close(self): """Close the underlying file handle.""" @@ -241,7 +203,6 @@ class _GFileBase(object): # Declare wrappers as staticmethods at the end so that we can # use them as decorators. - _error_wrapper = staticmethod(_error_wrapper) _synchronized = staticmethod(_synchronized) @@ -284,40 +245,21 @@ class _Nulllocker(object): pass -def _func_error_wrapper(fn): - """Decorator wrapping function errors.""" - @functools.wraps(fn) # Preserve methods' __doc__ - def wrap(*args, **kwargs): - try: - return fn(*args, **kwargs) - except ValueError as e: - raise FileError(errno.EIO, e.message) - except IOError as e: - raise FileError(e) - except OSError as e: - raise GOSError(e) - return wrap - - -@_func_error_wrapper def Exists(path): # pylint: disable=invalid-name """Retruns True iff "path" exists (as a dir, file, non-broken symlink).""" return os.path.exists(path) -@_func_error_wrapper def IsDirectory(path): # pylint: disable=invalid-name """Return True iff "path" exists and is a directory.""" return os.path.isdir(path) -@_func_error_wrapper def Glob(glob): # pylint: disable=invalid-name """Return a list of filenames matching the glob "glob".""" return _glob.glob(glob) -@_func_error_wrapper def MkDir(path, mode=0o755): # pylint: disable=invalid-name """Create the directory "path" with the given mode. @@ -329,12 +271,11 @@ def MkDir(path, mode=0o755): # pylint: disable=invalid-name None Raises: - GOSError: if the path already exists + OSError: if the path already exists """ os.mkdir(path, mode) -@_func_error_wrapper def MakeDirs(path, mode=0o755): # pylint: disable=invalid-name """Recursively create the directory "path" with the given mode. @@ -347,12 +288,11 @@ def MakeDirs(path, mode=0o755): # pylint: disable=invalid-name Raises: - GOSError: if the path already exists + OSError: if the path already exists """ os.makedirs(path, mode) -@_func_error_wrapper def RmDir(directory): # pylint: disable=invalid-name """Removes the directory "directory" iff the directory is empty. @@ -360,12 +300,11 @@ def RmDir(directory): # pylint: disable=invalid-name directory: The directory to remove. Raises: - GOSError: If the directory does not exist or is not empty. + OSError: If the directory does not exist or is not empty. """ os.rmdir(directory) -@_func_error_wrapper def Remove(path): # pylint: disable=invalid-name """Delete the (non-directory) file "path". @@ -373,12 +312,11 @@ def Remove(path): # pylint: disable=invalid-name path: The file to remove. Raises: - GOSError: If "path" does not exist, is a directory, or cannot be deleted. + OSError: If "path" does not exist, is a directory, or cannot be deleted. """ os.remove(path) -@_func_error_wrapper def DeleteRecursively(path): # pylint: disable=invalid-name """Delete the file or directory "path" recursively. @@ -386,7 +324,7 @@ def DeleteRecursively(path): # pylint: disable=invalid-name path: The path to remove (may be a non-empty directory). Raises: - GOSError: If the path does not exist or cannot be deleted. + OSError: If the path does not exist or cannot be deleted. """ if IsDirectory(path): shutil.rmtree(path) @@ -394,7 +332,6 @@ def DeleteRecursively(path): # pylint: disable=invalid-name Remove(path) -@_func_error_wrapper def ListDirectory(directory, return_dotfiles=False): # pylint: disable=invalid-name """Returns a list of files in dir. @@ -415,7 +352,7 @@ def ListDirectory(directory, return_dotfiles=False): # pylint: disable=invalid- Other entries starting with a dot will only be returned if return_dotfiles is True. Raises: - GOSError: if there is an error retrieving the directory listing. + OSError: if there is an error retrieving the directory listing. """ files = os.listdir(directory) if not return_dotfiles: diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py index 37cf28eb560..9212a6ea346 100644 --- a/tensorflow/python/platform/default/_logging.py +++ b/tensorflow/python/platform/default/_logging.py @@ -23,9 +23,9 @@ from __future__ import print_function import logging import os +import six import sys import time -import thread from logging import DEBUG from logging import ERROR from logging import FATAL @@ -199,7 +199,7 @@ def set_verbosity(verbosity): def _get_thread_id(): """Get id of current thread, suitable for logging as an unsigned quantity.""" - thread_id = thread.get_ident() + thread_id = six.moves._thread.get_ident() return thread_id & _THREAD_ID_MASK diff --git a/tensorflow/python/platform/default/gfile_test.py b/tensorflow/python/platform/default/gfile_test.py index fc9fc5f55ed..023dac56de5 100644 --- a/tensorflow/python/platform/default/gfile_test.py +++ b/tensorflow/python/platform/default/gfile_test.py @@ -57,7 +57,7 @@ class _GFileBaseTest(_BaseTest): with self.gfile(self.tmp + "test_with", "w") as fh: fh.write("hi") with self.gfile(self.tmp + "test_with", "r") as fh: - self.assertEquals(fh.read(), "hi") + self.assertEqual(fh.read(), "hi") def testSizeAndTellAndSeek(self): with self.gfile(self.tmp + "test_tell", "w") as fh: @@ -90,17 +90,19 @@ class _GFileBaseTest(_BaseTest): def testErrors(self): self.assertRaises( - gfile.FileError, lambda: self.gfile(self.tmp + "doesnt_exist", "r")) + IOError, lambda: self.gfile(self.tmp + "doesnt_exist", "r")) with self.gfile(self.tmp + "test_error", "w") as fh: - self.assertRaises(gfile.FileError, lambda: fh.seek(-1)) + # Raises FileError inside Google and ValueError outside, so we + # can only test for Exception. + self.assertRaises(Exception, lambda: fh.seek(-1)) # test_error now exists, we can read from it: with self.gfile(self.tmp + "test_error", "r") as fh: - self.assertRaises(gfile.FileError, lambda: fh.write("ack")) + self.assertRaises(IOError, lambda: fh.write("ack")) fh = self.gfile(self.tmp + "test_error", "w") self.assertFalse(fh.closed) fh.close() self.assertTrue(fh.closed) - self.assertRaises(gfile.FileError, lambda: fh.write("ack")) + self.assertRaises(ValueError, lambda: fh.write("ack")) def testIteration(self): with self.gfile(self.tmp + "test_iter", "w") as fh: @@ -147,16 +149,16 @@ class FunctionTests(_BaseTest, googletest.TestCase): def testErrors(self): self.assertRaises( - gfile.GOSError, lambda: gfile.RmDir(self.tmp + "dir_doesnt_exist")) + OSError, lambda: gfile.RmDir(self.tmp + "dir_doesnt_exist")) self.assertRaises( - gfile.GOSError, lambda: gfile.Remove(self.tmp + "file_doesnt_exist")) + OSError, lambda: gfile.Remove(self.tmp + "file_doesnt_exist")) gfile.MkDir(self.tmp + "error_dir") with gfile.GFile(self.tmp + "error_dir/file", "w"): pass # Create file self.assertRaises( - gfile.GOSError, lambda: gfile.Remove(self.tmp + "error_dir")) + OSError, lambda: gfile.Remove(self.tmp + "error_dir")) self.assertRaises( - gfile.GOSError, lambda: gfile.RmDir(self.tmp + "error_dir")) + OSError, lambda: gfile.RmDir(self.tmp + "error_dir")) self.assertTrue(gfile.Exists(self.tmp + "error_dir")) gfile.DeleteRecursively(self.tmp + "error_dir") self.assertFalse(gfile.Exists(self.tmp + "error_dir")) diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py index 27c44652ed1..3cc7e493d04 100644 --- a/tensorflow/python/summary/event_accumulator_test.py +++ b/tensorflow/python/summary/event_accumulator_test.py @@ -58,7 +58,7 @@ class _EventGenerator(object): summary=tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)])) self.AddEvent(event) - def AddImage(self, tag, wall_time=0, step=0, encoded_image_string='imgstr', + def AddImage(self, tag, wall_time=0, step=0, encoded_image_string=b'imgstr', width=150, height=100): image = tf.Summary.Image(encoded_image_string=encoded_image_string, width=width, height=height) @@ -307,13 +307,13 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def testImages(self): gen = _EventGenerator() acc = ea.EventAccumulator(gen) - im1 = ea.ImageEvent(wall_time=1, step=10, encoded_image_string='big', + im1 = ea.ImageEvent(wall_time=1, step=10, encoded_image_string=b'big', width=400, height=300) - im2 = ea.ImageEvent(wall_time=2, step=12, encoded_image_string='small', + im2 = ea.ImageEvent(wall_time=2, step=12, encoded_image_string=b'small', width=40, height=30) - gen.AddImage('im1', wall_time=1, step=10, encoded_image_string='big', + gen.AddImage('im1', wall_time=1, step=10, encoded_image_string=b'big', width=400, height=300) - gen.AddImage('im2', wall_time=2, step=12, encoded_image_string='small', + gen.AddImage('im2', wall_time=2, step=12, encoded_image_string=b'small', width=40, height=30) acc.Reload() self.assertEqual(acc.Images('im1'), [im1]) diff --git a/tensorflow/python/summary/impl/directory_watcher.py b/tensorflow/python/summary/impl/directory_watcher.py index 5a97106740e..587c7a6d30c 100644 --- a/tensorflow/python/summary/impl/directory_watcher.py +++ b/tensorflow/python/summary/impl/directory_watcher.py @@ -56,7 +56,7 @@ class DirectoryWatcher(object): self._directory = directory self._loader_factory = loader_factory self._loader = None - self._path = None + self._path = '' self._path_filter = path_filter def Load(self): diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py index 29e1b04194e..bd5deb0e0ac 100644 --- a/tensorflow/python/summary/impl/event_file_loader.py +++ b/tensorflow/python/summary/impl/event_file_loader.py @@ -22,6 +22,7 @@ from tensorflow.core.util import event_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import app from tensorflow.python.platform import logging +from tensorflow.python.util import compat class EventFileLoader(object): @@ -31,7 +32,8 @@ class EventFileLoader(object): if file_path is None: raise ValueError('A file path is required') logging.debug('Opening a record reader pointing at %s', file_path) - self._reader = pywrap_tensorflow.PyRecordReader_New(file_path, 0) + self._reader = pywrap_tensorflow.PyRecordReader_New( + compat.as_bytes(file_path), 0) # Store it for logging purposes. self._file_path = file_path if not self._reader: diff --git a/tensorflow/python/summary/impl/event_file_loader_test.py b/tensorflow/python/summary/impl/event_file_loader_test.py index 6220970084f..4f6ac75d839 100644 --- a/tensorflow/python/summary/impl/event_file_loader_test.py +++ b/tensorflow/python/summary/impl/event_file_loader_test.py @@ -28,8 +28,8 @@ from tensorflow.python.summary.impl import event_file_loader class EventFileLoaderTest(test_util.TensorFlowTestCase): # A record containing a simple event. - RECORD = ('\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu' - '\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d') + RECORD = (b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu' + b'\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d') def _WriteToFile(self, filename, data): path = os.path.join(self.get_temp_dir(), filename) @@ -41,37 +41,37 @@ class EventFileLoaderTest(test_util.TensorFlowTestCase): os.path.join(self.get_temp_dir(), filename)) def testEmptyEventFile(self): - self._WriteToFile('empty_event_file', '') + self._WriteToFile('empty_event_file', b'') loader = self._LoaderForTestFile('empty_event_file') - self.assertEquals(len(list(loader.Load())), 0) + self.assertEqual(len(list(loader.Load())), 0) def testSingleWrite(self): self._WriteToFile('single_event_file', EventFileLoaderTest.RECORD) loader = self._LoaderForTestFile('single_event_file') events = list(loader.Load()) - self.assertEquals(len(events), 1) - self.assertEquals(events[0].wall_time, 1440183447.0) - self.assertEquals(len(list(loader.Load())), 0) + self.assertEqual(len(events), 1) + self.assertEqual(events[0].wall_time, 1440183447.0) + self.assertEqual(len(list(loader.Load())), 0) def testMultipleWrites(self): self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD) loader = self._LoaderForTestFile('staggered_event_file') - self.assertEquals(len(list(loader.Load())), 1) + self.assertEqual(len(list(loader.Load())), 1) self._WriteToFile('staggered_event_file', EventFileLoaderTest.RECORD) - self.assertEquals(len(list(loader.Load())), 1) + self.assertEqual(len(list(loader.Load())), 1) def testMultipleLoads(self): self._WriteToFile('multiple_loads_event_file', EventFileLoaderTest.RECORD) loader = self._LoaderForTestFile('multiple_loads_event_file') loader.Load() loader.Load() - self.assertEquals(len(list(loader.Load())), 1) + self.assertEqual(len(list(loader.Load())), 1) def testMultipleWritesAtOnce(self): self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD) self._WriteToFile('multiple_event_file', EventFileLoaderTest.RECORD) loader = self._LoaderForTestFile('staggered_event_file') - self.assertEquals(len(list(loader.Load())), 2) + self.assertEqual(len(list(loader.Load())), 2) if __name__ == '__main__': diff --git a/tensorflow/python/summary/impl/reservoir.py b/tensorflow/python/summary/impl/reservoir.py index 44b3b2a58ce..c5b5daff0c9 100644 --- a/tensorflow/python/summary/impl/reservoir.py +++ b/tensorflow/python/summary/impl/reservoir.py @@ -211,7 +211,7 @@ class _ReservoirBucket(object): """ with self._mutex: size_before = len(self.items) - self.items = filter(filterFn, self.items) + self.items = list(filter(filterFn, self.items)) size_diff = size_before - len(self.items) # Estimate a correction the the number of items seen diff --git a/tensorflow/python/summary/impl/reservoir_test.py b/tensorflow/python/summary/impl/reservoir_test.py index 0493d36ac6b..b7f72e64de9 100644 --- a/tensorflow/python/summary/impl/reservoir_test.py +++ b/tensorflow/python/summary/impl/reservoir_test.py @@ -20,11 +20,12 @@ from __future__ import print_function import tensorflow.python.platform from six.moves import xrange # pylint: disable=redefined-builtin -from tensorflow.python.platform import googletest +import tensorflow as tf + from tensorflow.python.summary.impl import reservoir -class ReservoirTest(googletest.TestCase): +class ReservoirTest(tf.test.TestCase): def testEmptyReservoir(self): r = reservoir.Reservoir(1) @@ -94,7 +95,7 @@ class ReservoirTest(googletest.TestCase): self.assertNotEqual(r1.Items(key), r2.Items(key)) -class ReservoirBucketTest(googletest.TestCase): +class ReservoirBucketTest(tf.test.TestCase): def testEmptyBucket(self): b = reservoir._ReservoirBucket(1) @@ -119,7 +120,7 @@ class ReservoirBucketTest(googletest.TestCase): for i in xrange(10000): b.AddItem(i) items = b.Items() - prev = None + prev = -1 for item in items: self.assertTrue(item > prev) prev = item @@ -175,7 +176,7 @@ class ReservoirBucketTest(googletest.TestCase): int(round(10000 * (1 - float(num_removed) / 100)))) -class ReservoirBucketStatisticalDistributionTest(googletest.TestCase): +class ReservoirBucketStatisticalDistributionTest(tf.test.TestCase): def setUp(self): self.total = 1000000 @@ -222,4 +223,4 @@ class ReservoirBucketStatisticalDistributionTest(googletest.TestCase): if __name__ == '__main__': - googletest.main() + tf.test.main() diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index d26f12a89ce..ce9770b3f47 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /* SWIG wrapper for all of TensorFlow native functionality. * The includes are intentionally not alphabetically sorted, as the order of * includes follows dependency order */ diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py index a1e5b966fe0..864e549b25f 100644 --- a/tensorflow/python/training/adagrad.py +++ b/tensorflow/python/training/adagrad.py @@ -43,7 +43,7 @@ class AdagradOptimizer(optimizer.Optimizer): gradients. Defaults to "Adagrad". Raises: - ValueError: If the initial_accumulator_value is invalid. + ValueError: If the `initial_accumulator_value` is invalid. """ if initial_accumulator_value <= 0.0: raise ValueError("initial_accumulator_value must be positive: %s" % diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index 08b55ece59b..805d00a441a 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Coordinator to help multiple threads stop when requested.""" from __future__ import absolute_import from __future__ import division @@ -23,6 +22,9 @@ import threading import time from tensorflow.python.platform import logging +from tensorflow.python.util import compat + +import six class Coordinator(object): @@ -122,21 +124,23 @@ class Coordinator(object): def request_stop(self, ex=None): """Request that the threads stop. - After this is called, calls to should_stop() will return True. + After this is called, calls to `should_stop()` will return `True`. Args: - ex: Optional Exception, or Python 'exc_info' tuple as returned by - sys.exc_info(). If this is the first call to request_stop() the - corresponding exception is recorded and re-raised from join(). + ex: Optional `Exception`, or Python `exc_info` tuple as returned by + `sys.exc_info()`. If this is the first call to `request_stop()` the + corresponding exception is recorded and re-raised from `join()`. """ with self._lock: if not self._stop_event.is_set(): if ex and self._exc_info_to_raise is None: if isinstance(ex, tuple): - logging.info("Error reported to Coordinator: %s", str(ex[1])) + logging.info("Error reported to Coordinator: %s", + compat.as_str(unicode(ex[1]))) self._exc_info_to_raise = ex else: - logging.info("Error reported to Coordinator: %s", str(ex)) + logging.info("Error reported to Coordinator: %s", + compat.as_str(unicode(ex))) self._exc_info_to_raise = sys.exc_info() self._stop_event.set() @@ -163,24 +167,24 @@ class Coordinator(object): def join(self, threads, stop_grace_period_secs=120): """Wait for threads to terminate. - Blocks until all 'threads' have terminated or request_stop() is called. + Blocks until all `threads` have terminated or `request_stop()` is called. - After the threads stop, if an 'exc_info' was passed to request_stop, that + After the threads stop, if an `exc_info` was passed to `request_stop`, that exception is re-reaised. - Grace period handling: When request_stop() is called, threads are given + Grace period handling: When `request_stop()` is called, threads are given 'stop_grace_period_secs' seconds to terminate. If any of them is still - alive after that period expires, a RuntimeError is raised. Note that if - an 'exc_info' was passed to request_stop() then it is raised instead of - that RuntimeError. + alive after that period expires, a `RuntimeError` is raised. Note that if + an `exc_info` was passed to `request_stop()` then it is raised instead of + that `RuntimeError`. Args: - threads: List threading.Threads. The started threads to join. + threads: List of `threading.Threads`. The started threads to join. stop_grace_period_secs: Number of seconds given to threads to stop after - request_stop() has been called. + `request_stop()` has been called. Raises: - RuntimeError: If any thread is still alive after request_stop() + RuntimeError: If any thread is still alive after `request_stop()` is called and the grace period expires. """ # Wait for all threads to stop or for request_stop() to be called. @@ -198,8 +202,7 @@ class Coordinator(object): # Terminate with an exception if appropriate. with self._lock: if self._exc_info_to_raise: - exc_info = self._exc_info_to_raise - raise exc_info[0], exc_info[1], exc_info[2] + six.reraise(*self._exc_info_to_raise) elif stragglers: raise RuntimeError("Coordinator stopped with threads still running: %s", " ".join(stragglers)) diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py index 1ac7d6bdc60..0a74df34d6a 100644 --- a/tensorflow/python/training/ftrl.py +++ b/tensorflow/python/training/ftrl.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_ops @@ -71,8 +71,8 @@ def _Compute(accum, linear, base_lr, lr_power, l1, l2): A Tensor which is "variable" after update """ with ops.name_scope("compute_" + accum.op.name): - one_t = constant_op.constant(1.0, dtype=types.float32) - two_t = constant_op.constant(2.0, dtype=types.float32) + one_t = constant_op.constant(1.0, dtype=dtypes.float32) + two_t = constant_op.constant(2.0, dtype=dtypes.float32) learning_rate = math_ops.pow(accum, lr_power) * base_lr quadratic = one_t / learning_rate + two_t * l2 w = _Solve(quadratic, linear, l1) diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index dfb414f4c27..e580e80eb27 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -24,9 +24,9 @@ from __future__ import division from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import constant_op @@ -55,7 +55,7 @@ def match_filenames_once(pattern, name=None): def limit_epochs(tensor, num_epochs=None, name=None): - """Returns tensor num_epochs times and then raises an OutOfRange error. + """Returns tensor `num_epochs` times and then raises an `OutOfRange` error. Args: tensor: Any `Tensor`. @@ -64,14 +64,14 @@ def limit_epochs(tensor, num_epochs=None, name=None): name: A name for the operations (optional). Returns: - tensor or OutOfRange. + tensor or `OutOfRange`. """ if num_epochs is None: return tensor if num_epochs <= 0: raise ValueError("num_epochs must be > 0 not %d." % num_epochs) with ops.op_scope([tensor], name, "limit_epochs") as name: - zero64 = constant_op.constant(0, dtype=types.int64) + zero64 = constant_op.constant(0, dtype=dtypes.int64) epochs = variables.Variable(zero64, name="epochs") counter = epochs.count_up_to(num_epochs) with ops.control_dependencies([counter]): @@ -89,7 +89,7 @@ def _input_producer(input_tensor, dtype, num_epochs, shuffle, seed, capacity, enq = q.enqueue_many([input_tensor]) queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq])) summary_ops.scalar_summary("queue/%s/%s" % (q.name, summary_name), - math_ops.cast(q.size(), types.float32) * + math_ops.cast(q.size(), dtypes.float32) * (1. / capacity)) return q @@ -117,7 +117,7 @@ def string_input_producer(string_tensor, num_epochs=None, shuffle=True, """ with ops.op_scope([string_tensor], name, "input_producer") as name: return _input_producer( - string_tensor, types.string, num_epochs, shuffle, seed, capacity, name, + string_tensor, dtypes.string, num_epochs, shuffle, seed, capacity, name, "fraction_of_%d_full" % capacity) @@ -144,7 +144,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, with ops.op_scope([limit], name, "input_producer") as name: range_tensor = math_ops.range(limit) return _input_producer( - range_tensor, types.int32, num_epochs, shuffle, seed, capacity, name, + range_tensor, dtypes.int32, num_epochs, shuffle, seed, capacity, name, "fraction_of_%d_full" % capacity) @@ -208,14 +208,14 @@ def _validate_join(tensor_list_list): def _dtypes(tensor_list_list): - all_dtypes = [[t.dtype for t in tl] for tl in tensor_list_list] - dtypes = all_dtypes[0] - for other_dtypes in all_dtypes[1:]: - if other_dtypes != dtypes: + all_types = [[t.dtype for t in tl] for tl in tensor_list_list] + types = all_types[0] + for other_types in all_types[1:]: + if other_types != types: raise TypeError("Expected types to be consistent: %s vs. %s." % - ", ".join(x.name for x in dtypes), - ", ".join(x.name for x in other_dtypes)) - return dtypes + ", ".join(x.name for x in types), + ", ".join(x.name for x in other_types)) + return types def _merge_shapes(shape_list, enqueue_many): @@ -274,6 +274,12 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32, output will have shape `[batch_size, x, y, z]`. The `capacity` argument controls the how long the prefetching is allowed to grow the queues. + The returned operation is a dequeue operation and will throw + `tf.errors.OutOfRangeError` if the input queue is exhausted. If this + operation is feeding another input queue, its queue runner will catch + this exception, however, if this operation is used in your main thread + you are responsible for catching this yourself. + *N.B.:* You must ensure that either (i) the `shapes` argument is passed, or (ii) all of the tensors in `tensor_list` must have fully-defined shapes. `ValueError` will be raised if neither of @@ -298,15 +304,15 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32, """ with ops.op_scope(tensor_list, name, "batch") as name: tensor_list = _validate(tensor_list) - dtypes = _dtypes([tensor_list]) + types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many) # TODO(josh11b,mrry): Switch to BatchQueue once it is written. queue = data_flow_ops.FIFOQueue( - capacity=capacity, dtypes=dtypes, shapes=shapes) + capacity=capacity, dtypes=types, shapes=shapes) _enqueue(queue, tensor_list, num_threads, enqueue_many) summary_ops.scalar_summary( "queue/%s/fraction_of_%d_full" % (queue.name, capacity), - math_ops.cast(queue.size(), types.float32) * (1. / capacity)) + math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity)) return queue.dequeue_many(batch_size, name=name) @@ -344,6 +350,12 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False, The `capacity` argument controls the how long the prefetching is allowed to grow the queues. + The returned operation is a dequeue operation and will throw + `tf.errors.OutOfRangeError` if the input queue is exhausted. If this + operation is feeding another input queue, its queue runner will catch + this exception, however, if this operation is used in your main thread + you are responsible for catching this yourself. + *N.B.:* You must ensure that either (i) the `shapes` argument is passed, or (ii) all of the tensors in `tensor_list_list` must have fully-defined shapes. `ValueError` will be raised if neither of @@ -369,15 +381,15 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False, """ with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name: tensor_list_list = _validate_join(tensor_list_list) - dtypes = _dtypes(tensor_list_list) + types = _dtypes(tensor_list_list) shapes = _shapes(tensor_list_list, shapes, enqueue_many) # TODO(josh11b,mrry): Switch to BatchQueue once it is written. queue = data_flow_ops.FIFOQueue( - capacity=capacity, dtypes=dtypes, shapes=shapes) + capacity=capacity, dtypes=types, shapes=shapes) _enqueue_join(queue, tensor_list_list, enqueue_many) summary_ops.scalar_summary( "queue/%s/fraction_of_%d_full" % (queue.name, capacity), - math_ops.cast(queue.size(), types.float32) * (1. / capacity)) + math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity)) return queue.dequeue_many(batch_size, name=name) @@ -406,6 +418,12 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, The `capacity` argument controls the how long the prefetching is allowed to grow the queues. + The returned operation is a dequeue operation and will throw + `tf.errors.OutOfRangeError` if the input queue is exhausted. If this + operation is feeding another input queue, its queue runner will catch + this exception, however, if this operation is used in your main thread + you are responsible for catching this yourself. + For example: ```python @@ -445,14 +463,14 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, """ with ops.op_scope(tensor_list, name, "shuffle_batch") as name: tensor_list = _validate(tensor_list) - dtypes = _dtypes([tensor_list]) + types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many) queue = data_flow_ops.RandomShuffleQueue( capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, - dtypes=dtypes, shapes=shapes) + dtypes=types, shapes=shapes) _enqueue(queue, tensor_list, num_threads, enqueue_many) full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), - types.float32) * + dtypes.float32) * (1. / (capacity - min_after_dequeue))) # Note that name contains a '/' at the end so we intentionally do not place # a '/' after %s below. @@ -495,10 +513,11 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity, The `capacity` argument controls the how long the prefetching is allowed to grow the queues. - *N.B.:* You must ensure that either (i) the `shapes` argument is - passed, or (ii) all of the tensors in `tensor_list_list` must have - fully-defined shapes. `ValueError` will be raised if neither of - these conditions holds. + The returned operation is a dequeue operation and will throw + `tf.errors.OutOfRangeError` if the input queue is exhausted. If this + operation is feeding another input queue, its queue runner will catch + this exception, however, if this operation is used in your main thread + you are responsible for catching this yourself. Args: tensor_list_list: A list of tuples of tensors to enqueue. @@ -523,14 +542,14 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity, with ops.op_scope( _flatten(tensor_list_list), name, "shuffle_batch_join") as name: tensor_list_list = _validate_join(tensor_list_list) - dtypes = _dtypes(tensor_list_list) + types = _dtypes(tensor_list_list) shapes = _shapes(tensor_list_list, shapes, enqueue_many) queue = data_flow_ops.RandomShuffleQueue( capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, - dtypes=dtypes, shapes=shapes) + dtypes=types, shapes=shapes) _enqueue_join(queue, tensor_list_list, enqueue_many) full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), - types.float32) * + dtypes.float32) * (1. / (capacity - min_after_dequeue))) # Note that name contains a '/' at the end so we intentionally do not place # a '/' after %s below. diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index b2629f07407..8cbbd239d3a 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -26,6 +26,8 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf +from tensorflow.python.util import compat + class MatchFilenamesOnceTest(tf.test.TestCase): @@ -36,7 +38,7 @@ class MatchFilenamesOnceTest(tf.test.TestCase): for i in range(3)] for name in additional: open(name, "w").write("Some contents") - filenames += additional + filenames = list(set(filenames + additional)) with self.test_session(): star = tf.train.match_filenames_once( os.path.join(self.get_temp_dir(), "*")) @@ -44,9 +46,9 @@ class MatchFilenamesOnceTest(tf.test.TestCase): os.path.join(self.get_temp_dir(), "match_filenames.?")) one = tf.train.match_filenames_once(additional[1]) tf.initialize_all_variables().run() - self.assertItemsEqual(filenames, star.eval()) - self.assertItemsEqual(additional, question.eval()) - self.assertItemsEqual([additional[1]], one.eval()) + self.assertItemsEqual(map(compat.as_bytes, filenames), star.eval()) + self.assertItemsEqual(map(compat.as_bytes, additional), question.eval()) + self.assertItemsEqual([compat.as_bytes(additional[1])], one.eval()) class LimitEpochsTest(tf.test.TestCase): @@ -64,8 +66,8 @@ class LimitEpochsTest(tf.test.TestCase): love_me = tf.constant("Love Me") love_me_two_times = tf.train.limit_epochs(love_me, num_epochs=2) tf.initialize_all_variables().run() - self.assertEqual("Love Me", love_me_two_times.eval()) - self.assertEqual("Love Me", love_me_two_times.eval()) + self.assertEqual(b"Love Me", love_me_two_times.eval()) + self.assertEqual(b"Love Me", love_me_two_times.eval()) with self.assertRaises(tf.errors.OutOfRangeError): love_me_two_times.eval() @@ -74,7 +76,7 @@ class StringInputProducerTest(tf.test.TestCase): def testNoShuffle(self): with self.test_session(): - strings = ["to", "be", "or", "not", "to", "be"] + strings = [b"to", b"be", b"or", b"not", b"to", b"be"] num_epochs = 3 queue = tf.train.string_input_producer( strings, num_epochs=num_epochs, shuffle=False) @@ -95,7 +97,7 @@ class StringInputProducerTest(tf.test.TestCase): def testShuffle(self): with self.test_session(): - strings = ["a", "b", "c"] + strings = [b"a", b"b", b"c"] num_epochs = 600 queue = tf.train.string_input_producer( strings, num_epochs=num_epochs, shuffle=True, seed=271828) @@ -106,13 +108,13 @@ class StringInputProducerTest(tf.test.TestCase): # Validate that we only shuffle the strings within an epoch and # count how often each possible order appears. - expected = ["abc", "acb", "bac", "bca", "cab", "cba"] + expected = [b"abc", b"acb", b"bac", b"bca", b"cab", b"cba"] frequency = {} for e in expected: frequency[e] = 0 for _ in range(num_epochs): output = dequeue_many.eval() - key = "".join(output) + key = b"".join(output) self.assertIn(key, expected) frequency[key] += 1 @@ -199,7 +201,7 @@ class SliceInputProducerTest(tf.test.TestCase): def testNoShuffle(self): with self.test_session() as sess: num_epochs = 3 - source_strings = ["Alpha", "Beta", "Delta", "Gamma"] + source_strings = [b"Alpha", b"Beta", b"Delta", b"Gamma"] source_ints = [2, 3, 5, 7] slices = tf.train.slice_input_producer( [source_strings, source_ints], num_epochs=num_epochs, shuffle=False) @@ -232,14 +234,14 @@ class SliceInputProducerTest(tf.test.TestCase): # Validate that we only shuffle the integers within an epoch and # count how often each possible order appears. - expected = [",".join(x) for x in - itertools.permutations(["A7", "B3", "D5", "G2"])] + expected = [b",".join(x) for x in + itertools.permutations([b"A7", b"B3", b"D5", b"G2"])] frequency = {} for e in expected: frequency[e] = 0 for _ in range(num_epochs): output = [sess.run(slices) for _ in range(len(source_strings))] - key = ",".join([s + str(i) for s, i in output]) + key = b",".join([s + compat.as_bytes(str(i)) for s, i in output]) self.assertIn(key, expected) frequency[key] += 1 @@ -276,7 +278,7 @@ class BatchTest(tf.test.TestCase): results = sess.run(batched) self.assertAllEqual(results[0], np.arange(i * batch_size, (i + 1) * batch_size)) - self.assertAllEqual(results[1], ["string"] * batch_size) + self.assertAllEqual(results[1], [b"string"] * batch_size) # Reached the limit. with self.assertRaises(tf.errors.OutOfRangeError): @@ -301,7 +303,7 @@ class BatchTest(tf.test.TestCase): results = sess.run(batched) self.assertAllEqual(results[0], np.arange(i * batch_size, (i + 1) * batch_size)) - self.assertAllEqual(results[1], ["string"] * batch_size) + self.assertAllEqual(results[1], [b"string"] * batch_size) # Reached the limit. with self.assertRaises(tf.errors.OutOfRangeError): @@ -327,7 +329,7 @@ class BatchTest(tf.test.TestCase): tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) all_counts.extend(results[0]) - self.assertAllEqual(results[1], ["string"] * batch_size) + self.assertAllEqual(results[1], [b"string"] * batch_size) self.assertItemsEqual(all_counts, range(num_batches * batch_size)) # Reached the limit. @@ -369,8 +371,8 @@ class BatchJoinTest(tf.test.TestCase): tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) self.assertEqual(len(results[1]), batch_size) - which_a = [i for i, s in enumerate(results[1]) if s == "a"] - which_b = [i for i, s in enumerate(results[1]) if s == "b"] + which_a = [i for i, s in enumerate(results[1]) if s == b"a"] + which_b = [i for i, s in enumerate(results[1]) if s == b"b"] self.assertEqual(len(which_a) + len(which_b), batch_size) if len(which_a) > 0 and len(which_b) > 0: saw_both += 1 all_a.extend([results[0][i] for i in which_a]) @@ -412,7 +414,7 @@ class ShuffleBatchTest(tf.test.TestCase): results = sess.run(batched) self.assertEqual(len(results[0]), batch_size) all_counts.extend(results[0]) - self.assertAllEqual(results[1], ["string"] * batch_size) + self.assertAllEqual(results[1], [b"string"] * batch_size) # Results scrambled, but include all the expected numbers. deltas = [all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)] @@ -444,7 +446,7 @@ class ShuffleBatchTest(tf.test.TestCase): tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) all_counts.extend(results[0]) - self.assertAllEqual(results[1], ["string"] * batch_size) + self.assertAllEqual(results[1], [b"string"] * batch_size) # Results scrambled, but include all the expected numbers. deltas = [all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)] @@ -492,8 +494,8 @@ class ShuffleBatchJoinTest(tf.test.TestCase): tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) self.assertEqual(len(results[1]), batch_size) - which_a = [i for i, s in enumerate(results[1]) if s == "a"] - which_b = [i for i, s in enumerate(results[1]) if s == "b"] + which_a = [i for i, s in enumerate(results[1]) if s == b"a"] + which_b = [i for i, s in enumerate(results[1]) if s == b"b"] self.assertEqual(len(which_a) + len(which_b), batch_size) if len(which_a) > 0 and len(which_b) > 0: saw_both += 1 all_a.extend([results[0][i] for i in which_a]) diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py index dd597bc84a3..f5f6806f33d 100644 --- a/tensorflow/python/training/learning_rate_decay_test.py +++ b/tensorflow/python/training/learning_rate_decay_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import tensorflow.python.platform +from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -39,7 +39,7 @@ class LRDecayTest(test_util.TensorFlowTestCase): def testStaircase(self): with self.test_session(): - step = state_ops.variable_op([], types.int32) + step = state_ops.variable_op([], dtypes.int32) assign_100 = state_ops.assign(step, 100) assign_1 = state_ops.assign(step, 1) assign_2 = state_ops.assign(step, 2) diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py index 5666dc17203..1b9b6401862 100644 --- a/tensorflow/python/training/moving_averages.py +++ b/tensorflow/python/training/moving_averages.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types from tensorflow.python.ops import array_ops from tensorflow.python.ops import constant_op from tensorflow.python.ops import control_flow_ops @@ -205,7 +205,7 @@ class ExponentialMovingAverage(object): if var_list is None: var_list = variables.trainable_variables() for var in var_list: - if var.dtype.base_dtype not in [types.float32, types.float64]: + if var.dtype.base_dtype not in [dtypes.float32, dtypes.float64]: raise TypeError("The variables must be float or double: %s" % var) if var in self._averages: raise ValueError("Moving average already computed for: %s" % var) @@ -228,7 +228,7 @@ class ExponentialMovingAverage(object): with ops.name_scope(self._name) as scope: decay = ops.convert_to_tensor(self._decay, name="decay") if self._num_updates is not None: - num_updates = math_ops.cast(self._num_updates, types.float32, + num_updates = math_ops.cast(self._num_updates, dtypes.float32, name="num_updates") decay = math_ops.minimum(decay, (1.0 + num_updates) / (10.0 + num_updates)) diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py index cf4e84d403c..a2ad3a51b0e 100644 --- a/tensorflow/python/training/moving_averages_test.py +++ b/tensorflow/python/training/moving_averages_test.py @@ -21,9 +21,9 @@ from __future__ import print_function import tensorflow.python.platform from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.framework import types from tensorflow.python.ops import constant_op from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables @@ -36,7 +36,7 @@ class MovingAveragesTest(test_util.TensorFlowTestCase): def testAssignMovingAverage(self): with self.test_session(): var = variables.Variable([10.0, 11.0]) - val = constant_op.constant([1.0, 2.0], types.float32) + val = constant_op.constant([1.0, 2.0], dtypes.float32) decay = 0.25 assign = moving_averages.assign_moving_average(var, val, decay) variables.initialize_all_variables().run() @@ -151,7 +151,7 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase): with ops.device("dev_v0"): v0 = variables.Variable(10.0, name="v0") with ops.device("dev_v1"): - v1 = state_ops.variable_op(shape=[1], dtype=types.float32, name="v1") + v1 = state_ops.variable_op(shape=[1], dtype=dtypes.float32, name="v1") tensor2 = v0 + v1 ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg") with ops.device("default"): diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 18ac5b8bac9..dc2f700f816 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -20,8 +20,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import types as tf_types from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients @@ -38,7 +38,7 @@ class Optimizer(object): ### Usage - ``` + ```python # Create an optimizer with the desired parameters. opt = GradientDescentOptimizer(learning_rate=0.1) # Add Ops to the graph to minimize a cost by updating a list of variables. @@ -49,7 +49,7 @@ class Optimizer(object): In the training program you will just have to run the returned Op. - ``` + ```python # Execute opt_op to do one step of training: opt_op.run() ``` @@ -66,7 +66,7 @@ class Optimizer(object): Example: - ``` + ```python # Create an optimizer. opt = GradientDescentOptimizer(learning_rate=0.1) @@ -95,18 +95,18 @@ class Optimizer(object): The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`. - <b>GATE_NONE</b>: Compute and apply gradients in parallel. This provides the - maximum parallelism in execution, at the cost of some non-reproducibility in - the results. For example the two gradients of MatMul depend on the input + <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides + the maximum parallelism in execution, at the cost of some non-reproducibility + in the results. For example the two gradients of `matmul` depend on the input values: With `GATE_NONE` one of the gradients could be applied to one of the inputs _before_ the other gradient is computed resulting in non-reproducible results. - <b>GATE_OP</b>: For each Op, make sure all gradients are computed before they - are used. This prevents race conditions for Ops that generate gradients for - multiple inputs where the gradients depend on the inputs. + <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before + they are used. This prevents race conditions for Ops that generate gradients + for multiple inputs where the gradients depend on the inputs. - <b>GATE_GRAPH</b>: Make sure all gradients for all variables are computed + <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed before any one of them is used. This provides the least parallelism but can be useful if you want to process all gradients before applying any of them. @@ -154,32 +154,32 @@ class Optimizer(object): def minimize(self, loss, global_step=None, var_list=None, gate_gradients=GATE_OP, aggregation_method=None, name=None): - """Add operations to minimize 'loss' by updating 'var_list'. + """Add operations to minimize `loss` by updating `var_list`. - This method simply combines calls compute_gradients() and - apply_gradients(). If you want to process the gradient before applying them - call compute_gradients() and apply_gradients() explicitly instead of using - this function. + This method simply combines calls `compute_gradients()` and + `apply_gradients()`. If you want to process the gradient before applying + them call `compute_gradients()` and `apply_gradients()` explicitly instead + of using this function. Args: - loss: A Tensor containing the value to minimize. - global_step: Optional Variable to increment by one after the + loss: A `Tensor` containing the value to minimize. + global_step: Optional `Variable` to increment by one after the variables have been updated. - var_list: Optional list of variables.Variable to update to minimize - 'loss'. Defaults to the list of variables collected in the graph - under the key GraphKeys.TRAINABLE_VARIABLES. + var_list: Optional list of `Variable` objects to update to minimize + `loss`. Defaults to the list of variables collected in the graph + under the key `GraphKeys.TRAINABLE_VARIABLES`. gate_gradients: How to gate the computation of gradients. Can be - GATE_NONE, GATE_OP, or GATE_GRAPH. + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. name: Optional name for the returned operation. Returns: - An Operation that updates the variables in 'var_list'. If 'global_step' - was not None, that operation also increments global_step. + An Operation that updates the variables in `var_list`. If `global_step` + was not `None`, that operation also increments `global_step`. Raises: - ValueError: if some of the variables are not variables.Variable objects. + ValueError: if some of the variables are not `Variable` objects. """ grads_and_vars = self.compute_gradients( loss, var_list=var_list, gate_gradients=gate_gradients, @@ -189,21 +189,21 @@ class Optimizer(object): def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP, aggregation_method=None): - """Compute gradients of "loss" for the variables in "var_list". + """Compute gradients of `loss` for the variables in `var_list`. - This is the first part of minimize(). It returns a list + This is the first part of `minimize()`. It returns a list of (gradient, variable) pairs where "gradient" is the gradient - for "variable". Note that "gradient" can be a Tensor, a - IndexedSlices, or None if there is no gradient for the + for "variable". Note that "gradient" can be a `Tensor`, an + `IndexedSlices`, or `None` if there is no gradient for the given variable. Args: loss: A Tensor containing the value to minimize. var_list: Optional list of variables.Variable to update to minimize - "loss". Defaults to the list of variables collected in the graph - under the key GraphKey.TRAINABLE_VARIABLES. + `loss`. Defaults to the list of variables collected in the graph + under the key `GraphKey.TRAINABLE_VARIABLES`. gate_gradients: How to gate the computation of gradients. Can be - GATE_NONE, GATE_OP, or GATE_GRAPH. + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class `AggregationMethod`. @@ -211,7 +211,7 @@ class Optimizer(object): A list of (gradient, variable) pairs. Raises: - TypeError: If var_list contains anything else than variables.Variable. + TypeError: If `var_list` contains anything else than `Variable` objects. ValueError: If some arguments are invalid. """ if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP, @@ -237,27 +237,28 @@ class Optimizer(object): def apply_gradients(self, grads_and_vars, global_step=None, name=None): """Apply gradients to variables. - This is the second part of minimize(). It returns an Operation that + This is the second part of `minimize()`. It returns an `Operation` that applies gradients. Args: grads_and_vars: List of (gradient, variable) pairs as returned by - compute_gradients(). - global_step: Optional Variable to increment by one after the + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the variables have been updated. name: Optional name for the returned operation. Default to the - name passed to the Optimizer constructor. + name passed to the `Optimizer` constructor. Returns: - An Operation that applies the specified gradients. If 'global_step' - was not None, that operation also increments global_step. + An `Operation` that applies the specified gradients. If `global_step` + was not None, that operation also increments `global_step`. Raises: - TypeError: if grads_and_vars is malformed. + TypeError: if `grads_and_vars` is malformed. """ # This is a default implementation of apply_gradients() that can be shared # by most optimizers. It relies on the subclass implementing the following # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse(). + grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works for g, v in grads_and_vars: if not isinstance(g, (ops.Tensor, ops.IndexedSlices, type(None))): raise TypeError( @@ -287,20 +288,21 @@ class Optimizer(object): return state_ops.assign_add(global_step, 1, name=name).op def get_slot(self, var, name): - """Return a slot named "name" created for "var" by the Optimizer. + """Return a slot named `name` created for `var` by the Optimizer. - Some Optimizer subclasses use additional variables. For example - Momentum and Adagrad use variables to accumulate updates. This method - gives access to these Variables if for some reason you need them. + Some `Optimizer` subclasses use additional variables. For example + `Momentum` and `Adagrad` use variables to accumulate updates. This method + gives access to these `Variable` objects if for some reason you need them. - Use get_slot_names() to get the list of slot names created by the Optimizer. + Use `get_slot_names()` to get the list of slot names created by the + `Optimizer`. Args: - var: A variable passed to minimize() or apply_gradients(). + var: A variable passed to `minimize()` or `apply_gradients()`. name: A string. Returns: - The Variable for the slot if it was created, None otherwise. + The `Variable` for the slot if it was created, `None` otherwise. """ named_slots = self._slots.get(name, None) if not named_slots: @@ -308,9 +310,9 @@ class Optimizer(object): return named_slots.get(var, None) def get_slot_names(self): - """Return a list of the names of slots created by the Optimizer. + """Return a list of the names of slots created by the `Optimizer`. - See get_slot(). + See `get_slot()`. Returns: A list of strings. @@ -318,7 +320,7 @@ class Optimizer(object): return sorted(self._slots.keys()) def _assert_valid_dtypes(self, tensors): - """Asserts tensors are all valid types (see _valid_dtypes). + """Asserts tensors are all valid types (see `_valid_dtypes`). Args: tensors: tensors to check. @@ -340,18 +342,18 @@ class Optimizer(object): def _valid_dtypes(self): """Valid types for loss, variables and gradients. - Defaults to float32. Subclasses should override to allow other types. + Defaults to `float32`. Subclasses should override to allow other types. Returns: Valid types for loss, variables and gradients. """ - return set([tf_types.float32]) + return set([dtypes.float32]) def _create_slots(self, var_list): """Create all slots needed by the variables. Args: - var_list: A list of variables.Variable. + var_list: A list of `Variable` objects. """ # No slots needed by default pass @@ -365,38 +367,39 @@ class Optimizer(object): pass def _apply_dense(self, grad, var): - """Add ops to apply dense gradients to "var". + """Add ops to apply dense gradients to `var`. Args: - grad: A Tensor. - var: A variables.Variable. + grad: A `Tensor`. + var: A `Variable` object. Return: - An Operation. + An `Operation`. """ raise NotImplementedError() def _apply_sparse(self, grad, var): - """Add ops to apply sparse gradients to "var". + """Add ops to apply sparse gradients to `var`. Args: - grad: IndexedSlices. - var: A variables.Variable. + grad: `IndexedSlices`. + var: A `Variable` object. Return: - An Operation. + An `Operation`. """ raise NotImplementedError() def _finish(self, update_ops, name_scope): """Do what is needed to finish the update. - This is called with the name_scope using the "name" that + This is called with the `name_scope` using the "name" that users have chosen for the application of gradients. Args: - update_ops: List of Operations to update variables. This list contains - the values returned by the _apply_dense() and _apply_sparse() calls. + update_ops: List of `Operation` objects to update variables. This list + contains the values returned by the `_apply_dense()` and + `_apply_sparse()` calls. name_scope: string. Name to use for the returned operation. Returns: @@ -412,14 +415,14 @@ class Optimizer(object): """Find or create a slot for a variable. Args: - var: A variables.Variable. - val: A Tensor. The initial value of the slot. + var: A `Variable` object. + val: A `Tensor`. The initial value of the slot. slot_name: Name for the slot. op_name: Name to use when scoping the Variable that needs to be created for the slot. Returns: - A variables.Variable. + A `Variable` object. """ named_slots = self._slots.get(slot_name, None) if named_slots is None: @@ -439,13 +442,13 @@ class Optimizer(object): """Find or create a slot initialized with 0.0. Args: - var: A variables.Variable. + var: A `Variable` object. slot_name: Name for the slot. op_name: Name to use when scoping the Variable that needs to be created for the slot. Returns: - A variables.Variable. + A `Variable` object. """ val = array_ops.zeros(var.get_shape().as_list(), dtype=var.dtype) return self._get_or_make_slot(var, val, slot_name, op_name) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 37e33213c69..4b86a08609d 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -20,10 +20,10 @@ from __future__ import division from __future__ import print_function import collections -import numbers import os.path import time +import numpy as np import six from google.protobuf import text_format @@ -44,6 +44,7 @@ from tensorflow.python.platform import logging from tensorflow.python.training import saver_pb2 from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState +from tensorflow.python.util import compat class BaseSaverBuilder(object): @@ -516,13 +517,13 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None): checkpoint_dir, latest_filename) f = None try: - # Check that the file exists before opeining it to avoid + # Check that the file exists before opening it to avoid # many lines of errors from colossus in the logs. if gfile.Exists(coord_checkpoint_filename): f = gfile.FastGFile(coord_checkpoint_filename, mode="r") ckpt = CheckpointState() text_format.Merge(f.read(), ckpt) - except gfile.FileError: + except IOError: # It's ok if the file cannot be read return None except text_format.ParseError as e: @@ -657,39 +658,39 @@ class Saver(object): saver = tf.train.Saver({v.op.name: v for v in [v1, v2]}) ``` - The optional `reshape` argument, if True, allows restoring a variable from + The optional `reshape` argument, if `True`, allows restoring a variable from a save file where the variable had a different shape, but the same number of elements and type. This is useful if you have reshaped a variable and want to reload it from an older checkpoint. - The optional `sharded` argument, if True, instructs the saver to shard + The optional `sharded` argument, if `True`, instructs the saver to shard checkpoints per device. Args: - var_list: A list of Variables or a dictionary mapping names to - Variables. If None, defaults to the list of all variables. - reshape: If True, allows restoring parameters from a checkpoint + var_list: A list of `Variable` objects or a dictionary mapping names to + variables. If `None`, defaults to the list of all variables. + reshape: If `True`, allows restoring parameters from a checkpoint where the variables have a different shape. - sharded: If True, shard the checkpoints, one per device. + sharded: If `True`, shard the checkpoints, one per device. max_to_keep: maximum number of recent checkpoints to keep. Defaults to 10,000 hours. keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults to 10,000 hours. name: string. Optional name to use as a prefix when adding operations. - restore_sequentially: A Bool, which if true, causes restore of different + restore_sequentially: A `Bool`, which if true, causes restore of different variables to happen sequentially within each device. This can lower memory usage when restoring very large models. - saver_def: Optional SaverDef proto to use instead of running the builder. - This is only useful for specialty code that wants to recreate a Saver - object for a previously built Graph that had a Saver. The saver_def - proto should be the one returned by the as_saver_def() call of the - Saver that was created for that Graph. - builder: Optional SaverBuilder to use if a saver_def was not provided. - Defaults to BaseSaverBuilder(). + saver_def: Optional `SaverDef` proto to use instead of running the + builder. This is only useful for specialty code that wants to recreate + a `Saver` object for a previously built `Graph` that had a `Saver`. + The `saver_def` proto should be the one returned by the + `as_saver_def()` call of the `Saver` that was created for that `Graph`. + builder: Optional `SaverBuilder` to use if a `saver_def` was not provided. + Defaults to `BaseSaverBuilder()`. Raises: TypeError: If `var_list` is invalid. - ValueError: If any of the keys or values in `var_list` is not unique. + ValueError: If any of the keys or values in `var_list` are not unique. """ if saver_def is None: if builder is None: @@ -728,26 +729,25 @@ class Saver(object): self._last_checkpoints = [] def _CheckpointFilename(self, p): - """Returns the checkpoint file name. - - If p is (filename, time) pair, return p[0]; else return p. + """Returns the checkpoint filename given a `(filename, time)` pair. Args: - p: (filename, time) pair or just checkpoint filename. + p: (filename, time) pair. Returns: Checkpoint file name. """ - return p[0] if isinstance(p, tuple) else p + name, _ = p + return name def _MaybeDeleteOldCheckpoints(self, latest_save_path): """Deletes old checkpoints if necessary. - Always keep the last max_to_keep checkpoints. If - keep_checkpoint_every_n_hours was specified, keep an additional checkpoint - every N hours. For example, if N is 0.5, an additional checkpoint is kept - for every 0.5 hours of training; if N is 10, an additional checkpoint is - kept for every 10 hours of training. + Always keep the last `max_to_keep` checkpoints. If + `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint + every `N` hours. For example, if `N` is 0.5, an additional checkpoint is + kept for every 0.5 hours of training; if `N` is 10, an additional + checkpoint is kept for every 10 hours of training. Args: latest_save_path: Name including path of checkpoint file to save. @@ -774,7 +774,7 @@ class Saver(object): for f in gfile.Glob(self._CheckpointFilename(p)): try: gfile.Remove(f) - except gfile.GOSError as e: + except OSError as e: logging.warning("Ignoring: %s", str(e)) def as_saver_def(self): @@ -803,17 +803,20 @@ class Saver(object): return list(self._CheckpointFilename(p) for p in self._last_checkpoints) def set_last_checkpoints(self, last_checkpoints): - """Sets the list of not-yet-deleted checkpoint filenames. + """Sets the list of old checkpoint filenames. Args: - last_checkpoints: a list of checkpoint filenames. + last_checkpoints: A list of checkpoint filenames. Raises: - AssertionError: if the list of checkpoint filenames has already been set. + AssertionError: If the list of checkpoint filenames has already been set. """ assert not self._last_checkpoints assert isinstance(last_checkpoints, list) - self._last_checkpoints = list(last_checkpoints) + # We use a timestamp of +inf so that this checkpoint will never be + # deleted. This is both safe and backwards compatible to a previous + # version of the code which used s[1] as the "timestamp". + self._last_checkpoints = [(s, np.inf) for s in last_checkpoints] def save(self, sess, save_path, global_step=None, latest_filename=None): """Saves variables. @@ -831,7 +834,7 @@ class Saver(object): `sharded`, this is the prefix of the sharded checkpoint filename. global_step: If provided the global step number is appended to `save_path` to create the checkpoint filename. The optional argument - can be a Tensor, a Tensor name or an integer. + can be a `Tensor`, a `Tensor` name or an integer. latest_filename: Optional name for the protocol buffer file that will contains the list of most recent checkpoint filenames. That file, kept in the same directory as the checkpoint files, is automatically @@ -844,12 +847,12 @@ class Saver(object): is the number of shards created. Raises: - TypeError: If `sess` is not a Session. + TypeError: If `sess` is not a `Session`. """ if latest_filename is None: latest_filename = "checkpoint" if global_step is not None: - if not isinstance(global_step, numbers.Number): + if not isinstance(global_step, compat.integral_types): global_step = training_util.global_step(sess, global_step) checkpoint_file = "%s-%d" % (save_path, global_step) else: @@ -860,7 +863,7 @@ class Saver(object): model_checkpoint_path = sess.run( self._save_tensor_name, {self._filename_tensor_name: checkpoint_file}) - model_checkpoint_path = str(model_checkpoint_path) + model_checkpoint_path = compat.as_str(model_checkpoint_path) self._MaybeDeleteOldCheckpoints(model_checkpoint_path) update_checkpoint_state(save_path, model_checkpoint_path, self.last_checkpoints, latest_filename) @@ -878,7 +881,7 @@ class Saver(object): `save()` call, or a call to `latest_checkpoint()`. Args: - sess: A Session to use to restore the parameters. + sess: A `Session` to use to restore the parameters. save_path: Path where parameters were previously saved. """ sess.run([self._restore_op_name], {self._filename_tensor_name: save_path}) @@ -894,7 +897,7 @@ def latest_checkpoint(checkpoint_dir, latest_filename=None): See the corresponding argument to `Saver.save()`. Returns: - The full path to the latest checkpoint or None if no checkpoint was found. + The full path to the latest checkpoint or `None` if no checkpoint was found. """ # Pick the latest checkpoint based on checkpoint state. ckpt = get_checkpoint_state(checkpoint_dir, latest_filename) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index e490ed78d42..c586831766b 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -343,7 +343,7 @@ class MaxToKeepTest(tf.test.TestCase): save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded") try: gfile.DeleteRecursively(save_dir) - except gfile.GOSError as _: + except OSError: pass # Ignore gfile.MakeDirs(save_dir) @@ -432,7 +432,7 @@ class MaxToKeepTest(tf.test.TestCase): save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded") try: gfile.DeleteRecursively(save_dir) - except gfile.GOSError as _: + except OSError: pass # Ignore gfile.MakeDirs(save_dir) @@ -470,7 +470,7 @@ class KeepCheckpointEveryNHoursTest(tf.test.TestCase): "keep_checkpoint_every_n_hours") try: gfile.DeleteRecursively(save_dir) - except gfile.GOSError as _: + except OSError: pass # Ignore gfile.MakeDirs(save_dir) diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py index 4d9868ef00b..f10e984e6d7 100644 --- a/tensorflow/python/training/summary_io.py +++ b/tensorflow/python/training/summary_io.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import os.path -import Queue import threading import time @@ -31,6 +30,7 @@ from tensorflow.core.util import event_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile +from tensorflow.python.util import compat class SummaryWriter(object): @@ -93,9 +93,9 @@ class SummaryWriter(object): self._logdir = logdir if not gfile.IsDirectory(self._logdir): gfile.MakeDirs(self._logdir) - self._event_queue = Queue.Queue(max_queue) + self._event_queue = six.moves.queue.Queue(max_queue) self._ev_writer = pywrap_tensorflow.EventsWriter( - os.path.join(self._logdir, "events")) + compat.as_bytes(os.path.join(self._logdir, "events"))) self._worker = _EventLoggerThread(self._event_queue, self._ev_writer, flush_secs) self._worker.start() @@ -120,7 +120,7 @@ class SummaryWriter(object): global_step: Number. Optional global step value to record with the summary. """ - if isinstance(summary, six.binary_type): + if isinstance(summary, bytes): summ = summary_pb2.Summary() summ.ParseFromString(summary) summary = summ diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py index 9e93562ed72..d18fd0fdf32 100644 --- a/tensorflow/python/training/training_ops_test.py +++ b/tensorflow/python/training/training_ops_test.py @@ -24,8 +24,9 @@ import itertools import tensorflow.python.platform import numpy as np +import tensorflow as tf -from tensorflow.python.framework import types +from tensorflow.python.framework import dtypes from tensorflow.python.framework.test_util import TensorFlowTestCase from tensorflow.python.ops import constant_op from tensorflow.python.ops import variables @@ -37,13 +38,13 @@ class TrainingOpsTest(TensorFlowTestCase): def _toType(self, dtype): if dtype == np.float32: - return types.float32 + return tf.float32 elif dtype == np.float64: - return types.float64 + return tf.float64 elif dtype == np.int32: - return types.int32 + return tf.int32 elif dtype == np.int64: - return types.int64 + return tf.int64 else: assert False, (dtype) diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py index 374b863220d..3d2b6d9c84b 100644 --- a/tensorflow/python/training/training_util.py +++ b/tensorflow/python/training/training_util.py @@ -52,7 +52,7 @@ def global_step(sess, global_step_tensor): def write_graph(graph_def, logdir, name, as_text=True): """Writes a graph proto on disk. - The graph is written as a binary proto unless as_text is `True`. + The graph is written as a binary proto unless `as_text` is `True`. ```python v = tf.Variable(0, name='my_variable') diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py new file mode 100644 index 00000000000..13134fae14d --- /dev/null +++ b/tensorflow/python/util/compat.py @@ -0,0 +1,91 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functions for Python 2 vs. 3 compatibility.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numbers +import numpy as np +import six + + +def as_bytes(bytes_or_text): + """Returns the given argument as a byte array. + + NOTE(mrry): For Python 2 and 3 compatibility, we convert all string + arguments to SWIG methods into byte arrays. Unicode strings are + encoded as UTF-8; however the valid arguments for all of the + human-readable arguments must currently be a subset of ASCII. + + Args: + bytes_or_text: A `unicode`, `string`, or `bytes` object. + + Returns: + A `bytes` object. + + Raises: + TypeError: If `bytes_or_text` is not a binary or unicode string. + """ + if isinstance(bytes_or_text, six.text_type): + return bytes_or_text.encode('utf-8') + elif isinstance(bytes_or_text, bytes): + return bytes_or_text + else: + raise TypeError('Expected binary or unicode string, got %r' % bytes_or_text) + + +def as_text(bytes_or_text): + """Returns the given argument as a unicode string. + + NOTE(mrry): For Python 2 and 3 compatibility, we interpret all + returned strings from SWIG methods as byte arrays. This function + converts those strings that are intended to be human-readable into + UTF-8 unicode strings. + + Args: + bytes_or_text: A `unicode`, `string`, or `bytes` object. + + Returns: + A `unicode` (Python 2) or `str` (Python 3) object. + + Raises: + TypeError: If `bytes_or_text` is not a binary or unicode string. + """ + if isinstance(bytes_or_text, six.text_type): + return bytes_or_text + elif isinstance(bytes_or_text, bytes): + return bytes_or_text.decode('utf-8') + else: + raise TypeError('Expected binary or unicode string, got %r' % bytes_or_text) + + +# Convert an object to a `str` in both Python 2 and 3 +if six.PY2: + as_str = as_bytes +else: + as_str = as_text + + +# Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we +# need to check them specifically. The same goes from Real and Complex. +integral_types = (numbers.Integral, np.integer) +real_types = (numbers.Real, np.integer, np.floating) +complex_types = (numbers.Complex, np.number) + + +# Either bytes or text +bytes_or_text_types = (bytes, six.text_type) diff --git a/tensorflow/python/util/port.i b/tensorflow/python/util/port.i index fdb217dcc78..568658cd7f2 100644 --- a/tensorflow/python/util/port.i +++ b/tensorflow/python/util/port.i @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + %include "tensorflow/python/platform/base.i" %{ diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 59a0e2cb623..97a15dd263c 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -137,6 +137,7 @@ string BatchDescriptor::ToShortString() const { return port::StrCat(batch, depth, y, x, suffix); default: LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout()); + return ""; // Avoid lack-of-return warning } } @@ -204,6 +205,7 @@ string FilterDescriptor::ToShortString() const { return port::StrCat(y, x, id, od); default: LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_); + return ""; // Avoid lack-of-return warning } } diff --git a/tensorflow/tensorboard/.gitignore b/tensorflow/tensorboard/.gitignore new file mode 100644 index 00000000000..9f4cfe9b129 --- /dev/null +++ b/tensorflow/tensorboard/.gitignore @@ -0,0 +1,15 @@ +bower_components/* +node_modules/* +typings/* +build/* +dist/tf-tensorboard-demo.html + +components/tf-graph/demo/tf_model_zoo/* + +# Js files in the graph visualizer migrated to typescript. These files +# are produced by the compiler and should not be submitted to the repo. +components/tf-graph-common/lib/*.js +components/tf-graph-common/lib/scene/*.js +components/tf-event-dashboard/*.js +components/tf-categorizer/*.js +components/tf-dashboard-common/*.js diff --git a/tensorflow/tensorboard/app/analytics.js b/tensorflow/tensorboard/app/analytics.js index e69de29bb2d..b61a151ac85 100644 --- a/tensorflow/tensorboard/app/analytics.js +++ b/tensorflow/tensorboard/app/analytics.js @@ -0,0 +1,16 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Nothing to see here. vulcanize doesn't like empty files. diff --git a/tensorflow/tensorboard/components/tf-categorizer/categorizer.ts b/tensorflow/tensorboard/components/tf-categorizer/categorizer.ts index e05078279e2..4a713a7783f 100644 --- a/tensorflow/tensorboard/components/tf-categorizer/categorizer.ts +++ b/tensorflow/tensorboard/components/tf-categorizer/categorizer.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../typings/tsd.d.ts" /> module Categorizer { diff --git a/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts b/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts index be09c56c410..360d644b423 100644 --- a/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts +++ b/tensorflow/tensorboard/components/tf-categorizer/test/categorizerTest.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../../typings/tsd.d.ts" /> /// <reference path="../categorizer.ts" /> var assert = chai.assert; diff --git a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts index c7bbcbf4340..cb8c97823f0 100644 --- a/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts +++ b/tensorflow/tensorboard/components/tf-dashboard-common/urlGenerator.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../typings/tsd.d.ts" /> /// <reference path="../../bower_components/plottable/plottable.d.ts" /> diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts index c489eca17cf..5083603125a 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/dataCoordinator.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../typings/tsd.d.ts" /> /// <reference path="../../bower_components/plottable/plottable.d.ts" /> diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts index de814583d38..ea2fd950e5d 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/dataset.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../typings/tsd.d.ts" /> /// <reference path="../../bower_components/plottable/plottable.d.ts" /> diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/dragZoomInteraction.ts b/tensorflow/tensorboard/components/tf-event-dashboard/dragZoomInteraction.ts index bf9f7b70e27..c65fb092b2e 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/dragZoomInteraction.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/dragZoomInteraction.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + module Plottable { export class DragZoomLayer extends Components.SelectionBoxLayer { private _dragInteraction: Interactions.Drag; diff --git a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts index 4ea7fdf83c2..349b50001b0 100644 --- a/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts +++ b/tensorflow/tensorboard/components/tf-event-dashboard/tf-chart.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../typings/tsd.d.ts" /> /// <reference path="../../bower_components/plottable/plottable.d.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/colors.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/colors.ts index 8912483c09e..2c699d27490 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/colors.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/colors.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + module tf { /** diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts index ed148bf719c..17b753a2f9a 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/common.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../../typings/tsd.d.ts" /> declare module graphlib { diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts index ac3a5211630..8d1faaceecc 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/graph.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../../typings/tsd.d.ts" /> /// <reference path="common.ts" /> module tf.graph { diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts index 17f7a0b4053..1c8e1b2e182 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/hierarchy.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="graph.ts" /> /// <reference path="template.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts index 4eb3cab0116..5a0559627f2 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/layout.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="graph.ts" /> /// <reference path="render.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts index b4864738a96..0b8a308fe35 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/parser.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../../typings/tsd.d.ts" /> /// <reference path="common.ts" /> module tf.graph.parser { diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts index 50216bce226..81e1873137f 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/render.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="graph.ts" /> /// <reference path="hierarchy.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts index 7f6c9ff4b64..6823d8b48c7 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/annotation.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../graph.ts" /> /// <reference path="../render.ts" /> /// <reference path="scene.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts index e11ec97f807..d84bb8e2ca9 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/edge.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../graph.ts" /> /// <reference path="../render.ts" /> /// <reference path="scene.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts index 1a341327658..c3cf3d684ba 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/minimap.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../../../../typings/tsd.d.ts" /> /// <reference path="../common.ts" /> @@ -13,6 +28,8 @@ export class Minimap { private canvas: HTMLCanvasElement; /** A buffer canvas used for temporary drawing to avoid flickering. */ private canvasBuffer: HTMLCanvasElement; + private download: HTMLLinkElement; + private downloadCanvas: HTMLCanvasElement; /** The minimap svg used for holding the viewpoint rectangle. */ private minimapSvg: SVGSVGElement; @@ -91,12 +108,16 @@ export class Minimap { this.viewpointCoord.y = clickCoords[1] - height / 2; this.updateViewpoint(); }); - this.viewpoint = <SVGRectElement> $viewpoint.node(); - this.minimapSvg = <SVGSVGElement> $minimapSvg.node(); + this.viewpoint = <SVGRectElement>$viewpoint.node(); + this.minimapSvg = <SVGSVGElement>$minimapSvg.node(); this.minimap = minimap; - this.canvas = <HTMLCanvasElement> $minimap.select("canvas.first").node(); + this.canvas = <HTMLCanvasElement>$minimap.select("canvas.first").node(); this.canvasBuffer = - <HTMLCanvasElement> $minimap.select("canvas.second").node(); + <HTMLCanvasElement>$minimap.select("canvas.second").node(); + this.downloadCanvas = + <HTMLCanvasElement>$minimap.select("canvas.download").node(); + d3.select(this.downloadCanvas).style("display", "none"); + } /** @@ -121,6 +142,12 @@ export class Minimap { * was updated (e.g. when a node was expanded). */ update(): void { + let $download = d3.select("#graphdownload"); + this.download = <HTMLLinkElement>$download.node(); + $download.on("click", d => { + this.download.href = this.downloadCanvas.toDataURL("image/png"); + }); + let $svg = d3.select(this.svg); // Read all the style rules in the document and embed them into the svg. // The svg needs to be self contained, i.e. all the style rules need to be @@ -156,7 +183,8 @@ export class Minimap { // Get the size of the entire scene. let sceneSize = this.zoomG.getBBox(); // Since we add padding, account for that here. - sceneSize.height += this.labelPadding; + sceneSize.height += this.labelPadding * 2; + sceneSize.width += this.labelPadding * 2; // Temporarily assign an explicit width/height to the main svg, since // it doesn't have one (uses flex-box), but we need it for the canvas @@ -182,6 +210,14 @@ export class Minimap { // viewpoint rect. d3.select(this.minimapSvg).attr(<any>this.minimapSize); d3.select(this.canvasBuffer).attr(<any>this.minimapSize); + + // Download canvas width and height are multiples of the style width and + // height in order to increase pixel density of the PNG for clarity. + d3.select(this.downloadCanvas).style( + <any>{ width: sceneSize.width, height: sceneSize.height }); + d3.select(this.downloadCanvas).attr( + <any>{ width: sceneSize.width * 3, height: sceneSize.height * 3 }); + if (this.translate != null && this.zoom != null) { // Update the viewpoint rectangle shape since the aspect ratio of the // map has changed. @@ -216,6 +252,11 @@ export class Minimap { // Swap the two canvases. [this.canvas, this.canvasBuffer] = [this.canvasBuffer, this.canvas]; }); + let downloadContext = this.downloadCanvas.getContext("2d"); + downloadContext.clearRect(0, 0, this.downloadCanvas.width, + this.downloadCanvas.height); + downloadContext.drawImage(image, 0, 0, + this.downloadCanvas.width, this.downloadCanvas.height); }; image.src = "data:image/svg+xml;base64," + btoa(svgXml); } diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts index 37d409f41a4..57496ceffb6 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/node.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../graph.ts" /> /// <reference path="scene.ts" /> /// <reference path="annotation.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts index 2e2467f0392..6e97e904daa 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/scene/scene.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="../graph.ts" /> /// <reference path="edge.ts" /> /// <reference path="node.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts index 2ec0f6d5e96..41fbbbb9ff6 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/lib/template.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + /// <reference path="graph.ts" /> /// <reference path="hierarchy.ts" /> diff --git a/tensorflow/tensorboard/components/tf-graph-common/test/graph-test.ts b/tensorflow/tensorboard/components/tf-graph-common/test/graph-test.ts index 25ca079d06b..a17e7d4b8eb 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/test/graph-test.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/test/graph-test.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + suite("graph", () => { let assert = chai.assert; diff --git a/tensorflow/tensorboard/components/tf-graph-common/test/hierarchy-test.ts b/tensorflow/tensorboard/components/tf-graph-common/test/hierarchy-test.ts index 25ca079d06b..a17e7d4b8eb 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/test/hierarchy-test.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/test/hierarchy-test.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + suite("graph", () => { let assert = chai.assert; diff --git a/tensorflow/tensorboard/components/tf-graph-common/test/layout-test.ts b/tensorflow/tensorboard/components/tf-graph-common/test/layout-test.ts index 59776400387..a6f8fef2f68 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/test/layout-test.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/test/layout-test.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + suite("layout", () => { let assert = chai.assert; diff --git a/tensorflow/tensorboard/components/tf-graph-common/test/parser-test.ts b/tensorflow/tensorboard/components/tf-graph-common/test/parser-test.ts index cbbb58ddc6c..6bea2f4eaa8 100644 --- a/tensorflow/tensorboard/components/tf-graph-common/test/parser-test.ts +++ b/tensorflow/tensorboard/components/tf-graph-common/test/parser-test.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + suite("parser", () => { let assert = chai.assert; diff --git a/tensorflow/tensorboard/components/tf-graph-loader/test/loader.ts b/tensorflow/tensorboard/components/tf-graph-loader/test/loader.ts index b03413a1211..c59955dad77 100644 --- a/tensorflow/tensorboard/components/tf-graph-loader/test/loader.ts +++ b/tensorflow/tensorboard/components/tf-graph-loader/test/loader.ts @@ -1,3 +1,18 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + suite("graph loader", () => { let assert = chai.assert; diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-controls.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-controls.html index 49a8d58e4d1..058e10e5e53 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-controls.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-controls.html @@ -61,10 +61,6 @@ table td { padding-bottom: 10px; } -#fit { - color: var(--paper-orange-500); -} - paper-radio-button { padding: 5px; } @@ -132,7 +128,7 @@ svg.icon { padding: 0 0 0 55px; } -.fit-button-text { +.button-text { text-transform: none; padding: 8px 18px 0 18px; font-size: 14px @@ -145,10 +141,11 @@ svg.icon { margin-top: 4px; } -.fit-button { +.iconbutton { padding: 2px; width: 30px; height: 30px; + color: var(--paper-orange-500); } .hidden-input { @@ -162,21 +159,21 @@ svg.icon { clear: both; } </style> -<svg style="display:none;"> - <defs> - <!-- Summary icon. --> - <svg id="summary-icon" fill="#848484" height="12" viewBox="0 0 24 24" width="12"> - <path d="M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z" /> - </svg> - </defs> -</svg> <div class="allcontrols"> <div class="control-holder"> - <paper-icon-button id="fit" icon="aspect-ratio" class="fit-button" on-click="fit" alt="Fit to screen"> + <paper-icon-button icon="aspect-ratio" class="iconbutton" on-click="fit" alt="Fit to screen"> </paper-icon-button> - <paper-button class="fit-button-text" on-click="fit">Fit to screen + <paper-button class="button-text" on-click="fit">Fit to screen </paper-button> </div> + <div class="control-holder"> + <paper-icon-button icon="file-download" class="iconbutton" on-click="download" alt="Download PNG"> + </paper-icon-button> + <paper-button class="button-text" on-click="download">Download PNG + </paper-button> + <a href="#" id="graphdownload" class="title" download="graph.png"> + </a> + </div> <div class="control-holder"> <div class="title">Run</div> <paper-dropdown-menu no-label-float no-animations noink class="run-dropdown"> @@ -383,6 +380,7 @@ Polymer({ type: Number, notify: true, value: 0, + observer: '_selectedDatasetChanged' }, selectedFile: { type: Object, @@ -434,17 +432,44 @@ Polymer({ endColor: params.endColor }; }, + download: function() { + this.$.graphdownload.click(); + }, _updateFileInput: function(e) { + var file = e.target.files[0]; + if (!file) { + return; + } + this._setDownloadFilename(file.name); this.set('selectedFile', e); }, _datasetsChanged: function(newDatasets, oldDatasets) { if (oldDatasets != null || this.selected == null) { // Select the first dataset by default. this.set('selectedDataset', 0); + this._setDownloadFilename(this.datasets[this.selectedDataset].path); + } + }, + _selectedDatasetChanged: function(newDataset, oldDataset) { + if (this.datasets) { + this._setDownloadFilename(this.datasets[newDataset].path); } }, _getFile: function() { this.$.file.click(); + }, + _setDownloadFilename: function(graphPath) { + // Strip off everything before the last "/" and strip off the file + // extension in order to get the name of the PNG for the graph. + var dotIndex = graphPath.lastIndexOf('.'); + if (dotIndex) { + graphPath = graphPath.substring(0, dotIndex); + } + var slashIndex = graphPath.lastIndexOf('/'); + if (slashIndex) { + graphPath = graphPath.substring(slashIndex + 1); + } + this.$.graphdownload.setAttribute('download', graphPath + '.png'); } }); diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-minimap.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-minimap.html index 2b6beeadedf..71deacba5e1 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-minimap.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-minimap.html @@ -45,6 +45,7 @@ svg { <canvas class="first"></canvas> <!-- Additional canvas to use as buffer to avoid flickering between updates --> <canvas class="second"></canvas> +<canvas class="download"></canvas> </template> <script> Polymer({ diff --git a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html index 51ea6497188..6b7512d4e8d 100644 --- a/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html +++ b/tensorflow/tensorboard/components/tf-graph/tf-graph-scene.html @@ -75,6 +75,9 @@ <use xlink:href="#op-node-annotation-stamp" x="7" y="2" /> <use xlink:href="#op-node-annotation-stamp" x="5" y="2" /> </g> + <svg id="summary-icon" fill="#848484" height="12" viewBox="0 0 24 24" width="12"> + <path d="M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z" /> + </svg> <!-- Where the linearGradient for each node is stored. Used when coloring by proportions of devices. diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e1f279f20cb..150e5321abc 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -37,7 +37,7 @@ def if_cuda(a, b=[]): def tf_copts(): - return ["-pthread", "-fno-exceptions",] + if_cuda(["-DGOOGLE_CUDA=1"]) + return ["-pthread", "-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] + if_cuda(["-DGOOGLE_CUDA=1"]) # Given a list of "op_lib_names" (a list of files in the ops directory @@ -149,6 +149,7 @@ def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[], # Make a py_library out of the generated python file. native.py_library(name=name, srcs=[out], + srcs_version="PY2AND3", visibility=visibility, deps=[ "//tensorflow/core:protos_all_py", @@ -309,6 +310,7 @@ def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs): deps=deps) native.py_library(name=name, srcs=[":" + name + ".py"], + srcs_version="PY2AND3", data=[":" + cc_library_name]) diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc index cdbd1823ea1..a67b0390005 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc @@ -1,4 +1,18 @@ #!/usr/bin/env python2 +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== """Crosstool wrapper for compiling CUDA programs.