TensorFlow: Upstream changes to git
Changes: - Updates to docs - Several changes for Python 3 compatibility - Added license headers Base CL: 108710566
@ -144,8 +144,8 @@ tf_cuda_library(
|
||||
name = "gpu_runtime",
|
||||
srcs = glob(
|
||||
[
|
||||
"common_runtime/gpu/**/*.h",
|
||||
"common_runtime/gpu/**/*.cc",
|
||||
"common_runtime/gpu/*.h",
|
||||
"common_runtime/gpu/*.cc",
|
||||
],
|
||||
exclude = [
|
||||
"**/*main.cc",
|
||||
@ -628,6 +628,7 @@ filegroup(
|
||||
"//tensorflow/core:kernels/relu_op.h",
|
||||
"//tensorflow/core:kernels/softplus_op.cc",
|
||||
"//tensorflow/core:kernels/softplus_op.h",
|
||||
"//tensorflow/core:kernels/stack_ops.cc",
|
||||
"//tensorflow/core:kernels/transpose_op.cc",
|
||||
"//tensorflow/core:kernels/transpose_op.h",
|
||||
"//tensorflow/core:kernels/transpose_op_functor.h",
|
||||
@ -673,6 +674,7 @@ cc_library(
|
||||
copts = [
|
||||
"-mfpu=neon",
|
||||
"-std=c++11",
|
||||
"-O2",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
|
@ -60,6 +60,7 @@ const OpDef* OpRegistry::LookUp(const string& op_type_name,
|
||||
if (op_def == nullptr) {
|
||||
status->Update(
|
||||
errors::NotFound("Op type not registered '", op_type_name, "'"));
|
||||
LOG(INFO) << status->ToString();
|
||||
static bool first_unregistered = true;
|
||||
if (first_unregistered) {
|
||||
OpList op_list;
|
||||
|
@ -817,6 +817,11 @@ class OpKernelContext {
|
||||
return output_allocation_types_[index];
|
||||
}
|
||||
|
||||
// Per-step resource manager for use by white-listed internal ops.
|
||||
ResourceMgr* step_resource_manager() const {
|
||||
return params_.step_resource_manager;
|
||||
}
|
||||
|
||||
private:
|
||||
Allocator* get_allocator(AllocatorAttributes attr) {
|
||||
Allocator* allocator = params_.device->GetAllocator(attr);
|
||||
@ -836,13 +841,6 @@ class OpKernelContext {
|
||||
}
|
||||
}
|
||||
|
||||
// Per-step resource manager for use by white-listed internal ops.
|
||||
friend class TemporaryVariableOp;
|
||||
friend class DestroyTemporaryVariableOp;
|
||||
ResourceMgr* step_resource_manager() const {
|
||||
return params_.step_resource_manager;
|
||||
}
|
||||
|
||||
// Internal common method used when allocating tensor memory
|
||||
Status allocate_tensor(DataType type, const TensorShape& shape,
|
||||
Tensor* out_tensor, AllocatorAttributes attr);
|
||||
|
@ -140,7 +140,7 @@ TEST(TensorUtil, Concat) {
|
||||
std::vector<Tensor> to_concat;
|
||||
int64 total_size = 0;
|
||||
int offset = 0;
|
||||
for (int entry = 0; entry < sizes.size(); ++entry) {
|
||||
for (size_t entry = 0; entry < sizes.size(); ++entry) {
|
||||
const int64 size = sizes[entry];
|
||||
Tensor tensor(DT_INT32, TensorShape({size, 2}));
|
||||
for (int i = offset; i < offset + size; ++i) {
|
||||
@ -175,7 +175,7 @@ TEST(TensorUtil, Split) {
|
||||
ASSERT_EQ(sizes.size(), splits.size());
|
||||
|
||||
int offset = 0;
|
||||
for (int entry = 0; entry < splits.size(); ++entry) {
|
||||
for (size_t entry = 0; entry < splits.size(); ++entry) {
|
||||
const int64 size = sizes[entry];
|
||||
const Tensor& split = splits[entry];
|
||||
|
||||
|
@ -1011,6 +1011,10 @@ Status Partition(const PartitionOptions& opts, Graph* g,
|
||||
|
||||
if (!edge->IsControlEdge() &&
|
||||
IsRefType(src->output_type(edge->src_output()))) {
|
||||
AddNodeAttr("_start_time", recv_start_time, recv);
|
||||
if (real_recv != recv) {
|
||||
AddNodeAttr("_start_time", recv_start_time, real_recv);
|
||||
}
|
||||
// If src is of ref type and the edge is not a control edge, dst has
|
||||
// read semantics and therefore we must control the recv.
|
||||
ref_recvs.push_back(real_recv);
|
||||
|
@ -37,7 +37,7 @@ __global__ void BiasOpCustomKernel(int nthreads, const T* input, const T* bias,
|
||||
T* output) {
|
||||
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||
int bias_offset = index % bias_size;
|
||||
output[index] = __ldg(input + index) + __ldg(bias + bias_offset);
|
||||
output[index] = ldg(input + index) + ldg(bias + bias_offset);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,9 +42,10 @@ struct scalar_const_op {
|
||||
return *val;
|
||||
}
|
||||
|
||||
template <typename Index>
|
||||
EIGEN_STRONG_INLINE const Packet packetOp(Index, Index = 0) const {
|
||||
return internal::pset1<Packet>(*val);
|
||||
template <typename Index, typename PacketType = Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const PacketType
|
||||
packetOp(Index, Index = 0) const {
|
||||
return internal::pset1<PacketType>(*val);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -383,12 +383,53 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
|
||||
// The output image size is the spatial size of the output.
|
||||
const int output_image_size = out_rows * out_cols;
|
||||
|
||||
// TODO(andydavis) Get L2/L3 cache sizes from device.
|
||||
const size_t l2_cache_size = 256LL << 10;
|
||||
const size_t l3_cache_size = 30LL << 20;
|
||||
|
||||
// Use L3 cache size as target working set size.
|
||||
const size_t target_working_set_size = l3_cache_size / sizeof(T);
|
||||
|
||||
// Calculate size of matrices involved in MatMul: C = A x B.
|
||||
const size_t size_A = output_image_size * out_depth;
|
||||
|
||||
const size_t size_B = filter_total_size * out_depth;
|
||||
|
||||
const size_t size_C = output_image_size * filter_total_size;
|
||||
|
||||
const size_t work_unit_size = size_A + size_B + size_C;
|
||||
|
||||
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
|
||||
|
||||
// Calculate per-thread work unit size.
|
||||
const size_t thread_work_unit_size =
|
||||
work_unit_size / worker_threads.num_threads;
|
||||
|
||||
// Set minimum per-thread work unit size to size of L2 cache.
|
||||
const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);
|
||||
|
||||
// Use parallel tensor contractions if there is no batching, or if the
|
||||
// minimum per-thread work unit size threshold has been exceeded.
|
||||
// Otherwise, revert to multiple single-threaded matmul ops running in
|
||||
// parallel to keep all threads busy.
|
||||
// TODO(andydavis) Explore alternatives to branching the code in this way
|
||||
// (i.e. run multiple, parallel tensor contractions in another thread pool).
|
||||
const bool use_parallel_contraction =
|
||||
batch == 1 || thread_work_unit_size >= min_thread_work_unit_size;
|
||||
|
||||
const size_t shard_size =
|
||||
use_parallel_contraction
|
||||
? 1
|
||||
: (target_working_set_size + work_unit_size - 1) / work_unit_size;
|
||||
|
||||
Tensor col_buffer;
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_temp(
|
||||
DataTypeToEnum<T>::value,
|
||||
TensorShape({output_image_size, filter_total_size}), &col_buffer));
|
||||
TensorShape({static_cast<int64>(shard_size),
|
||||
static_cast<int64>(output_image_size),
|
||||
static_cast<int64>(filter_total_size)}),
|
||||
&col_buffer));
|
||||
|
||||
// The input offset corresponding to a single input image.
|
||||
const int input_offset = input_rows * input_cols * in_depth;
|
||||
@ -400,6 +441,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
|
||||
auto* out_backprop_data = out_backprop.template flat<T>().data();
|
||||
auto* input_backprop_data = in_backprop->template flat<T>().data();
|
||||
|
||||
if (use_parallel_contraction) {
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
|
||||
Eigen::Unaligned> TensorMap;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
|
||||
@ -420,12 +462,54 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
|
||||
|
||||
C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
|
||||
|
||||
Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols, filter_rows,
|
||||
filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
|
||||
stride, input_backprop_data);
|
||||
Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols,
|
||||
filter_rows, filter_cols, pad_top, pad_left, pad_bottom,
|
||||
pad_right, stride, stride, input_backprop_data);
|
||||
|
||||
input_backprop_data += input_offset;
|
||||
}
|
||||
} else {
|
||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor>> MatrixMap;
|
||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor>> ConstMatrixMap;
|
||||
|
||||
for (int image_id = 0; image_id < batch; image_id += shard_size) {
|
||||
const int shard_limit = std::min(static_cast<int>(shard_size),
|
||||
static_cast<int>(batch) - image_id);
|
||||
|
||||
auto shard = [&in_depth, &input_rows, &input_cols, &filter_rows,
|
||||
&filter_cols, &pad_top, &pad_left, &pad_bottom,
|
||||
&pad_right, &stride, &output_image_size,
|
||||
&filter_total_size, &out_depth, &input_backprop_data,
|
||||
&col_buffer_data, &out_backprop_data, &filter_data,
|
||||
&input_offset, &output_offset,
|
||||
&size_C](int64 start, int64 limit) {
|
||||
for (int shard_id = start; shard_id < limit; ++shard_id) {
|
||||
T* im2col_buf = col_buffer_data + shard_id * size_C;
|
||||
T* input_data = input_backprop_data + shard_id * input_offset;
|
||||
const T* out_data = out_backprop_data + shard_id * output_offset;
|
||||
|
||||
// Compute gradient into 'im2col_buf'.
|
||||
MatrixMap C(im2col_buf, output_image_size, filter_total_size);
|
||||
|
||||
ConstMatrixMap A(out_data, output_image_size, out_depth);
|
||||
ConstMatrixMap B(filter_data, filter_total_size, out_depth);
|
||||
|
||||
C.noalias() = A * B.transpose();
|
||||
|
||||
Col2im<T>(im2col_buf, in_depth, input_rows, input_cols, filter_rows,
|
||||
filter_cols, pad_top, pad_left, pad_bottom, pad_right,
|
||||
stride, stride, input_data);
|
||||
}
|
||||
};
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
|
||||
work_unit_size, shard);
|
||||
|
||||
input_backprop_data += input_offset * shard_limit;
|
||||
out_backprop_data += output_offset * shard_limit;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
@ -620,8 +704,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
|
||||
&pad_left, &pad_bottom, &pad_right, &stride, &input_offset,
|
||||
&size_A](int64 start, int64 limit) {
|
||||
for (int shard_id = start; shard_id < limit; ++shard_id) {
|
||||
auto input_data_shard = input_data + shard_id * input_offset;
|
||||
auto col_data_shard = col_buffer_data + shard_id * size_A;
|
||||
const T* input_data_shard = input_data + shard_id * input_offset;
|
||||
T* col_data_shard = col_buffer_data + shard_id * size_A;
|
||||
|
||||
// When we compute the gradient with respect to the filters, we need
|
||||
// to do im2col to allow gemm-type computation.
|
||||
|
@ -30,7 +30,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
namespace functor {
|
||||
|
||||
// A simple array that contains data that can be passed between CPU and GPU.
|
||||
template <typename T, int IndexCount>
|
||||
template <typename T, int IndexCount, T DefaultValue>
|
||||
struct Array {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
|
||||
return data[index];
|
||||
@ -38,16 +38,57 @@ struct Array {
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
|
||||
return data[index];
|
||||
}
|
||||
int data[IndexCount];
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
|
||||
for (int i = 0; i < IndexCount; i++) {
|
||||
data[i] = DefaultValue;
|
||||
}
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
|
||||
data[0] = a0;
|
||||
for (int i = 1; i < IndexCount; i++) {
|
||||
data[i] = DefaultValue;
|
||||
}
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
|
||||
data[0] = a0;
|
||||
data[1] = a1;
|
||||
for (int i = 2; i < IndexCount; i++) {
|
||||
data[i] = DefaultValue;
|
||||
}
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
|
||||
data[0] = a0;
|
||||
data[1] = a1;
|
||||
data[2] = a2;
|
||||
for (int i = 3; i < IndexCount; i++) {
|
||||
data[i] = DefaultValue;
|
||||
}
|
||||
}
|
||||
T data[IndexCount];
|
||||
};
|
||||
|
||||
// A dimension type with compile-time known size.
|
||||
template <int IndexCount>
|
||||
struct Dimension : Array<int, IndexCount> {};
|
||||
struct Dimension : Array<int, IndexCount, 1> {
|
||||
typedef Array<int, IndexCount, 1> Base;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
|
||||
: Base(a0, a1) {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
|
||||
: Base(a0, a1, a2) {}
|
||||
};
|
||||
|
||||
// An index type with compile-time known size.
|
||||
template <int IndexCount>
|
||||
struct Index : Array<int, IndexCount> {};
|
||||
struct Index : Array<int, IndexCount, 0> {
|
||||
typedef Array<int, IndexCount, 0> Base;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
|
||||
: Base(a0, a1, a2) {}
|
||||
};
|
||||
|
||||
// A helper function that converts a tensor index into a flat array index.
|
||||
template <int IndexCount>
|
||||
@ -94,7 +135,7 @@ __global__ void SwapDimension0And2InTensor3(int nthreads, const T* input,
|
||||
|
||||
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||
|
||||
output[output_index] = __ldg(input + input_index);
|
||||
output[output_index] = ldg(input + input_index);
|
||||
}
|
||||
}
|
||||
|
||||
@ -119,7 +160,88 @@ __global__ void SwapDimension1And2InTensor3(int nthreads, const T* input,
|
||||
|
||||
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
|
||||
|
||||
output[output_index] = __ldg(input + input_index);
|
||||
output[output_index] = ldg(input + input_index);
|
||||
}
|
||||
}
|
||||
|
||||
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
|
||||
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
|
||||
// TileSize could be arbitrary. But for best performance, it is better to be
|
||||
// the same as number of threads in a warp, which is 32 for almost all GPU
|
||||
// architectures.
|
||||
template <typename T, int TileSize>
|
||||
__global__ void SwapDimension1And2InTensor3UsingTiles(const T* input,
|
||||
Dimension<3> input_dims,
|
||||
T* output) {
|
||||
// One extra line in the inner dimension to avoid share memory bank conflict.
|
||||
__shared__ T shared_memory_tile[TileSize][TileSize + 1];
|
||||
|
||||
int x = threadIdx.x;
|
||||
if (x >= TileSize) {
|
||||
return;
|
||||
}
|
||||
|
||||
Dimension<3> output_dims = {
|
||||
input_dims[0], input_dims[2], input_dims[1],
|
||||
};
|
||||
|
||||
Dimension<3> input_dims_in_tiles = {
|
||||
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
|
||||
(input_dims[2] + TileSize - 1) / TileSize,
|
||||
};
|
||||
|
||||
Index<3> input_tile_index =
|
||||
FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);
|
||||
|
||||
Index<3> input_tile_origin = {
|
||||
input_tile_index[0], input_tile_index[1] * TileSize,
|
||||
input_tile_index[2] * TileSize,
|
||||
};
|
||||
|
||||
int input_origin_flat_index =
|
||||
TensorIndexToFlat(input_tile_origin, input_dims);
|
||||
|
||||
int tile_width = TileSize;
|
||||
// Only the last row or column may not have the full size.
|
||||
if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
|
||||
tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSize;
|
||||
}
|
||||
int tile_height = TileSize;
|
||||
if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
|
||||
tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSize;
|
||||
}
|
||||
|
||||
// Load the data from input memory to the shared memory tile.
|
||||
if (x < tile_width) {
|
||||
int input_flat_index = input_origin_flat_index + x;
|
||||
for (int y = 0; y < tile_height; y++) {
|
||||
shared_memory_tile[y][x] = input[input_flat_index];
|
||||
input_flat_index += input_dims[2];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Index<3> output_tile_index = {
|
||||
input_tile_index[0], input_tile_index[2], input_tile_index[1],
|
||||
};
|
||||
|
||||
Index<3> output_tile_origin = {
|
||||
output_tile_index[0], output_tile_index[1] * TileSize,
|
||||
output_tile_index[2] * TileSize,
|
||||
};
|
||||
|
||||
int output_origin_flat_index =
|
||||
TensorIndexToFlat(output_tile_origin, output_dims);
|
||||
|
||||
int output_flat_index = output_origin_flat_index + x;
|
||||
|
||||
// Load the data from the shared memory tile to the output memory.
|
||||
if (x < tile_height) {
|
||||
for (int y = 0; y < tile_width; y++) {
|
||||
output[output_flat_index] = shared_memory_tile[x][y];
|
||||
output_flat_index += output_dims[2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -212,6 +334,39 @@ struct PadInput<GPUDevice, T, int> {
|
||||
}
|
||||
};
|
||||
|
||||
// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
|
||||
// 3D tensor. It looks at the shape of the incoming data, and decides the best
|
||||
// strategy to launch.
|
||||
template <typename T>
|
||||
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
|
||||
const Dimension<3>& input_dims, T* output) {
|
||||
// If both dimensions are not trivial, use tiles for the actual swapping.
|
||||
// Otherwise, the trivial swapping relying on the ldg cache is more efficient.
|
||||
static const int kMinDimensionToUseTiles = 16;
|
||||
bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles &&
|
||||
input_dims[2] >= kMinDimensionToUseTiles);
|
||||
if (use_tiles) {
|
||||
// The tile-size can be chosen to be arbitrary number. But it is better to
|
||||
// be the same as number of threads in a warp, which is 32.
|
||||
static const int TileSize = 32;
|
||||
Dimension<3> input_dims_in_tiles = {
|
||||
input_dims[0], (input_dims[1] + TileSize - 1) / TileSize,
|
||||
(input_dims[2] + TileSize - 1) / TileSize,
|
||||
};
|
||||
int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
|
||||
input_dims_in_tiles[2];
|
||||
SwapDimension1And2InTensor3UsingTiles<
|
||||
T, TileSize><<<total_tiles_count, TileSize, 0, d.stream()>>>(
|
||||
input, input_dims, output);
|
||||
} else {
|
||||
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
|
||||
SwapDimension1And2InTensor3<
|
||||
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, input, input_dims, output);
|
||||
}
|
||||
}
|
||||
|
||||
// A GPU helper functor that converts NHWC TensorFlow data format to
|
||||
// NCHW format that is accepted by Cudnn.
|
||||
template <typename T>
|
||||
@ -223,10 +378,7 @@ struct NHWCToNCHW<GPUDevice, T> {
|
||||
combined_dims[0] = in.dimension(0);
|
||||
combined_dims[1] = in.dimension(1) * in.dimension(2);
|
||||
combined_dims[2] = in.dimension(3);
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
|
||||
SwapDimension1And2InTensor3<
|
||||
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, in.data(), combined_dims, out.data());
|
||||
RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
|
||||
}
|
||||
};
|
||||
|
||||
@ -241,10 +393,7 @@ struct NCHWToNHWC<GPUDevice, T> {
|
||||
combined_dims[0] = in.dimension(0);
|
||||
combined_dims[1] = in.dimension(1);
|
||||
combined_dims[2] = in.dimension(2) * in.dimension(3);
|
||||
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
|
||||
SwapDimension1And2InTensor3<
|
||||
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
|
||||
config.virtual_thread_count, in.data(), combined_dims, out.data());
|
||||
RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -66,7 +66,7 @@ class LRNOp : public OpKernel {
|
||||
context->allocate_output(
|
||||
0, TensorShape({batch, rows, cols, depth}), &output));
|
||||
|
||||
#if !defined(__ANDROID__)
|
||||
#if defined(__ANDROID__)
|
||||
MognetLRN(in, batch, rows, cols, depth, output);
|
||||
#else
|
||||
const int nodes = cols * rows;
|
||||
|
@ -49,13 +49,16 @@ class ResizeNearestNeighborOp : public OpKernel {
|
||||
errors::InvalidArgument("shape_t must have two elements",
|
||||
shape_t.shape().ShortDebugString()));
|
||||
|
||||
auto Svec = shape_t.vec<int32>();
|
||||
auto sizes = shape_t.vec<int32>();
|
||||
OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
|
||||
errors::InvalidArgument("shape_t's elements must be positive"));
|
||||
|
||||
// Initialize shape to the batch size of the input, then add
|
||||
// the rest of the dimensions
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
0, TensorShape({input.dim_size(0), Svec(0),
|
||||
Svec(1), input.dim_size(3)}),
|
||||
0, TensorShape({input.dim_size(0), sizes(0),
|
||||
sizes(1), input.dim_size(3)}),
|
||||
&output));
|
||||
|
||||
const int64 batch_size = input.dim_size(0);
|
||||
@ -87,12 +90,93 @@ class ResizeNearestNeighborOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
class ResizeNearestNeighborOpGrad : public OpKernel {
|
||||
public:
|
||||
explicit ResizeNearestNeighborOpGrad(OpKernelConstruction* context)
|
||||
: OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Grab and validate the input:
|
||||
const Tensor& input = context->input(0);
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().ShortDebugString()));
|
||||
|
||||
// Grab and validate the output shape:
|
||||
const Tensor& shape_t = context->input(1);
|
||||
OP_REQUIRES(context, shape_t.dims() == 1,
|
||||
errors::InvalidArgument("shape_t must be 1-dimensional",
|
||||
shape_t.shape().ShortDebugString()));
|
||||
OP_REQUIRES(context, shape_t.NumElements() == 2,
|
||||
errors::InvalidArgument("shape_t must have two elements",
|
||||
shape_t.shape().ShortDebugString()));
|
||||
|
||||
auto sizes = shape_t.vec<int32>();
|
||||
OP_REQUIRES(context, sizes(0) > 0 && sizes(1) > 0,
|
||||
errors::InvalidArgument("shape_t's elements must be positive"));
|
||||
|
||||
// Initialize shape to the batch size of the input, then add
|
||||
// the rest of the dimensions
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
0, TensorShape({input.dim_size(0), sizes(0),
|
||||
sizes(1), input.dim_size(3)}),
|
||||
&output));
|
||||
|
||||
const int64 batch_size = input.dim_size(0);
|
||||
const int64 in_height = input.dim_size(1);
|
||||
const int64 in_width = input.dim_size(2);
|
||||
const int64 channels = input.dim_size(3);
|
||||
|
||||
const int64 out_height = output->dim_size(1);
|
||||
const int64 out_width = output->dim_size(2);
|
||||
|
||||
typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
|
||||
typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
|
||||
|
||||
const float height_scale = out_height / static_cast<float>(in_height);
|
||||
const float width_scale = out_width / static_cast<float>(in_width);
|
||||
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
for (int y = 0; y < out_height; ++y) {
|
||||
for (int x = 0; x < out_width; ++x) {
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
output_data(b, y, x, c) = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int c = 0; c < channels; ++c) {
|
||||
for (int y = 0; y < in_height; ++y) {
|
||||
const int out_y = std::min(static_cast<int64>(floorf(y * height_scale)),
|
||||
(out_height - 1));
|
||||
|
||||
for (int x = 0; x < in_width; ++x) {
|
||||
const int out_x = std::min(
|
||||
static_cast<int64>(floorf(x * width_scale)), (out_width - 1));
|
||||
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
output_data(b, out_y, out_x, c) += input_data(b, y, x, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighbor") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("size"), \
|
||||
ResizeNearestNeighborOp<CPUDevice, T>);
|
||||
ResizeNearestNeighborOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("ResizeNearestNeighborGrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("size"), \
|
||||
ResizeNearestNeighborOpGrad<CPUDevice, T>);
|
||||
|
||||
REGISTER_KERNEL(uint8);
|
||||
REGISTER_KERNEL(int8);
|
||||
|
@ -141,8 +141,7 @@ void SparseSlice::Initialize(const ConstMatrixMap& mat, int col_offset) {
|
||||
Index idx;
|
||||
int data3_size = 0;
|
||||
for (int i = 0; i < num_blocks; ++i) {
|
||||
int num_block_cols =
|
||||
std::min(block_size, num_cols - block_size * (num_blocks - 1));
|
||||
int num_block_cols = std::min(block_size, num_cols - block_size * i);
|
||||
for (int row = 0; row < num_rows; ++row) {
|
||||
idx3.m = static_cast<uint8>(row);
|
||||
const float* start =
|
||||
@ -457,13 +456,11 @@ class SparseMatMulOp : public OpKernel {
|
||||
errors::InvalidArgument("b is not a matrix"));
|
||||
|
||||
auto left = a.matrix<float>();
|
||||
auto right_mat = b.matrix<float>();
|
||||
auto right = b.matrix<float>();
|
||||
const int m = transpose_a_ ? left.dimension(1) : left.dimension(0);
|
||||
const int k = transpose_a_ ? left.dimension(0) : left.dimension(1);
|
||||
const int n =
|
||||
transpose_b_ ? right_mat.dimension(0) : right_mat.dimension(1);
|
||||
const int k2 =
|
||||
transpose_b_ ? right_mat.dimension(1) : right_mat.dimension(0);
|
||||
const int n = transpose_b_ ? right.dimension(0) : right.dimension(1);
|
||||
const int k2 = transpose_b_ ? right.dimension(1) : right.dimension(0);
|
||||
|
||||
OP_REQUIRES(ctx, k == k2,
|
||||
errors::InvalidArgument("Matrix size incompatible: a: ",
|
||||
@ -473,7 +470,7 @@ class SparseMatMulOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({m, n}), &output));
|
||||
auto out = output->matrix<float>();
|
||||
|
||||
if (!a_is_sparse_) {
|
||||
if (!a_is_sparse_ && !b_is_sparse_) {
|
||||
// Fallback to Eigen contract.
|
||||
// Note that we currently don't optimize the case where only right is
|
||||
// sparse. That can generally be handled by tranposing the order of the
|
||||
@ -482,26 +479,41 @@ class SparseMatMulOp : public OpKernel {
|
||||
dim_pair[0].first = transpose_a_ ? 0 : 1;
|
||||
dim_pair[0].second = transpose_b_ ? 1 : 0;
|
||||
out.device(ctx->template eigen_device<CPUDevice>()) =
|
||||
left.contract(right_mat, dim_pair);
|
||||
left.contract(right, dim_pair);
|
||||
return;
|
||||
}
|
||||
auto left_mat = &left;
|
||||
auto right_mat = &right;
|
||||
bool transpose_output = false;
|
||||
bool transpose_a = transpose_a_;
|
||||
bool transpose_b = transpose_b_;
|
||||
if (!a_is_sparse_) {
|
||||
// Swap the order of multiplications using the identity:
|
||||
// A * B = (B' * A')'.
|
||||
std::swap(left_mat, right_mat);
|
||||
std::swap(transpose_a, transpose_b);
|
||||
transpose_a = !transpose_a;
|
||||
transpose_b = !transpose_b;
|
||||
transpose_output = !transpose_output;
|
||||
}
|
||||
std::unique_ptr<Matrix> right_tr_mat;
|
||||
std::unique_ptr<TTypes<float>::ConstMatrix> right_tr_map;
|
||||
if (transpose_b_) {
|
||||
if (transpose_b) {
|
||||
// TODO(agarwal): avoid transposing the matrix here and directly handle
|
||||
// transpose in CreateDenseSlices.
|
||||
right_tr_mat.reset(new Matrix(k, n));
|
||||
right_tr_mat.reset(
|
||||
new Matrix(right_mat->dimension(1), right_mat->dimension(0)));
|
||||
Eigen::array<int, 2> perm({1, 0});
|
||||
right_tr_mat->device(ctx->template eigen_device<CPUDevice>()) =
|
||||
right_mat.shuffle(perm);
|
||||
right_mat->shuffle(perm);
|
||||
right_tr_map.reset(new TTypes<float>::ConstMatrix(
|
||||
right_tr_mat->data(), right_tr_mat->dimensions()));
|
||||
right_mat = right_tr_map.get();
|
||||
}
|
||||
TTypes<float>::ConstMatrix& right =
|
||||
transpose_b_ ? *right_tr_map : right_mat;
|
||||
|
||||
SparseMatMul(left, right, transpose_a_,
|
||||
ctx->device()->tensorflow_cpu_worker_threads(), &out);
|
||||
SparseMatMul(*left_mat, *right_mat, transpose_a,
|
||||
ctx->device()->tensorflow_cpu_worker_threads(),
|
||||
transpose_output, &out);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -510,7 +522,7 @@ class SparseMatMulOp : public OpKernel {
|
||||
static inline void SparseMatMul(
|
||||
const ConstMatrixMap& left, const ConstMatrixMap& right,
|
||||
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
|
||||
MatrixMap* output);
|
||||
bool transpose_output, MatrixMap* output);
|
||||
|
||||
// Computes multiplication of left and num_cols columns of right, and stores
|
||||
// the output block in *"output" at offsets "output_row_offset" and
|
||||
@ -518,9 +530,9 @@ class SparseMatMulOp : public OpKernel {
|
||||
// else adds the values to the existing values.
|
||||
static inline void ComputeOutputBlock(const std::vector<SparseSlice*>& left,
|
||||
const ConstMatrixMap& right,
|
||||
const int num_cols,
|
||||
int output_row_offset,
|
||||
int num_cols, int output_row_offset,
|
||||
int output_col_offset, bool assign,
|
||||
bool transpose_output,
|
||||
MatrixMap* output);
|
||||
|
||||
// Encodes "mat" using a sparse representation and stores that in
|
||||
@ -578,9 +590,10 @@ class SparseMatMulOp : public OpKernel {
|
||||
|
||||
inline void SparseMatMulOp::ComputeOutputBlock(
|
||||
const std::vector<SparseSlice*>& left, const ConstMatrixMap& right,
|
||||
const int num_cols, int output_row_offset, int output_col_offset,
|
||||
bool assign, MatrixMap* output) {
|
||||
const int num_rows = left[0]->num_rows;
|
||||
int num_cols, int output_row_offset, int output_col_offset, bool assign,
|
||||
bool transpose_output, MatrixMap* output) {
|
||||
static const Eigen::array<int, 2> perm({1, 0});
|
||||
int num_rows = left[0]->num_rows;
|
||||
const int rhs_num_cols = right.dimension(1);
|
||||
DCHECK_LE(num_cols, rhs_num_cols);
|
||||
Matrix out(num_rows, rhs_num_cols);
|
||||
@ -593,18 +606,33 @@ inline void SparseMatMulOp::ComputeOutputBlock(
|
||||
if (!assign) {
|
||||
const Eigen::array<int, 2> begin = {output_row_offset, output_col_offset};
|
||||
const Eigen::array<int, 2> sizes = {num_rows, num_cols};
|
||||
if (transpose_output) {
|
||||
if (num_cols == rhs_num_cols) {
|
||||
output->shuffle(perm).slice(begin, sizes) += out;
|
||||
} else {
|
||||
static const Eigen::array<int, 2> zero = {0, 0};
|
||||
output->shuffle(perm).slice(begin, sizes) += out.slice(zero, sizes);
|
||||
}
|
||||
} else {
|
||||
if (num_cols == rhs_num_cols) {
|
||||
output->slice(begin, sizes) += out;
|
||||
} else {
|
||||
static const Eigen::array<int, 2> zero = {0, 0};
|
||||
output->slice(begin, sizes) += out.slice(zero, sizes);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// output->slice(begin, sizes) = out.slice(zero, sizes), implemented
|
||||
// using memcpy.
|
||||
std::unique_ptr<Matrix> out_tr;
|
||||
if (transpose_output) {
|
||||
out_tr.reset(new Matrix(rhs_num_cols, num_rows));
|
||||
*out_tr = out.shuffle(perm);
|
||||
std::swap(output_row_offset, output_col_offset);
|
||||
std::swap(num_rows, num_cols);
|
||||
}
|
||||
const Matrix& final_out = transpose_output ? *out_tr : out;
|
||||
for (int i = 0; i < num_rows; ++i) {
|
||||
memcpy(&(*output)(output_row_offset + i, output_col_offset), &out(i, 0),
|
||||
num_cols * sizeof(float));
|
||||
memcpy(&(*output)(output_row_offset + i, output_col_offset),
|
||||
&final_out(i, 0), num_cols * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -807,7 +835,7 @@ inline void SparseMatMulOp::ComputeBlockSizes(const ConstMatrixMap& left,
|
||||
inline void SparseMatMulOp::SparseMatMul(
|
||||
const ConstMatrixMap& left, const ConstMatrixMap& right,
|
||||
bool transpose_left, const DeviceBase::CpuWorkerThreads* thread_pool,
|
||||
MatrixMap* output) {
|
||||
bool transpose_output, MatrixMap* output) {
|
||||
const int num_threads = thread_pool->num_threads;
|
||||
int KR, NR, KL, JB, IB;
|
||||
ComputeBlockSizes(left, right, transpose_left, num_threads, &KR, &NR, &KL,
|
||||
@ -868,7 +896,7 @@ inline void SparseMatMulOp::SparseMatMul(
|
||||
tasks.push_back(std::bind(
|
||||
&SparseMatMulOp::ComputeOutputBlock, block_left_slices,
|
||||
std::ref(*right_slices[j_inner]), num_cols, M * i_inner,
|
||||
N * j_inner + nb * NR, kb == 0, output));
|
||||
N * j_inner + nb * NR, kb == 0, transpose_output, output));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -29,8 +29,11 @@ random::SimplePhilox rnd(&philox);
|
||||
void Sparsify(Tensor* t, float sparsity) {
|
||||
const int64 N = t->NumElements();
|
||||
CHECK_LE(sparsity, 1);
|
||||
if (sparsity <= 0) return;
|
||||
auto flat = t->flat<float>();
|
||||
if (sparsity == 1) {
|
||||
flat.setZero();
|
||||
return;
|
||||
}
|
||||
static const uint32 K = 10000;
|
||||
for (int64 i = 0; i < N; ++i) {
|
||||
if (rnd.Uniform(K) < sparsity * K) {
|
||||
@ -55,25 +58,21 @@ Node* SparseMatMulNode(Graph* g, Node* in0, Node* in1, bool transpose_a,
|
||||
return ret;
|
||||
}
|
||||
|
||||
static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity,
|
||||
bool transpose_a, bool transpose_b,
|
||||
bool a_sparse, bool b_sparse) {
|
||||
a_sparse = a_sparse && (sparsity > 0);
|
||||
b_sparse = b_sparse && (sparsity > 0);
|
||||
static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d,
|
||||
float sparsity_a, float sparsity_b,
|
||||
bool transpose_a, bool transpose_b) {
|
||||
bool a_sparse = (sparsity_a > 0);
|
||||
bool b_sparse = (sparsity_b > 0);
|
||||
|
||||
auto left_shape = transpose_a ? TensorShape({d, m}) : TensorShape({m, d});
|
||||
Tensor left(DataTypeToEnum<float>::value, left_shape);
|
||||
left.flat<float>().setRandom();
|
||||
if (a_sparse) {
|
||||
Sparsify(&left, sparsity);
|
||||
}
|
||||
Sparsify(&left, sparsity_a);
|
||||
|
||||
auto right_shape = transpose_b ? TensorShape({n, d}) : TensorShape({d, n});
|
||||
Tensor right(DataTypeToEnum<float>::value, right_shape);
|
||||
right.flat<float>().setRandom();
|
||||
if (b_sparse) {
|
||||
Sparsify(&right, sparsity);
|
||||
}
|
||||
Sparsify(&right, sparsity_b);
|
||||
|
||||
SparseMatMulNode(g, test::graph::Constant(g, left),
|
||||
test::graph::Constant(g, right), transpose_a, transpose_b,
|
||||
@ -81,58 +80,61 @@ static Graph* SparseMatMulHelper(Graph* g, int m, int n, int d, float sparsity,
|
||||
return g;
|
||||
}
|
||||
|
||||
static Graph* SparseMatMul(int m, int n, int d, float sparsity,
|
||||
bool transpose_a, bool transpose_b) {
|
||||
static Graph* SparseMatMul(int m, int n, int d, float sparsity_a,
|
||||
float sparsity_b, bool transpose_a,
|
||||
bool transpose_b) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
return SparseMatMulHelper(g, m, n, d, sparsity, transpose_a, transpose_b,
|
||||
true, false);
|
||||
return SparseMatMulHelper(g, m, n, d, sparsity_a, sparsity_b, transpose_a,
|
||||
transpose_b);
|
||||
}
|
||||
|
||||
static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_a,
|
||||
float sparsity_b) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
if (sparsity_a == 0 && sparsity_b > 0) {
|
||||
SparseMatMulHelper(g, m, n, d, sparsity_a, false, false, false, false);
|
||||
SparseMatMulHelper(g, n, d, m, sparsity_b, true, true, true, false);
|
||||
SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false);
|
||||
} else {
|
||||
SparseMatMulHelper(g, m, n, d, sparsity_a, false, true, true, false);
|
||||
SparseMatMulHelper(g, d, n, m, sparsity_a, true, false, true, true);
|
||||
SparseMatMulHelper(g, m, d, n, sparsity_b, false, false, true, false);
|
||||
}
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_SPARSE(M, K, N, S) \
|
||||
static void BM_Sparse##_##M##_##K##_##N##_##S(int iters) { \
|
||||
#define BM_SPARSE(M, K, N, S1, S2, TA, TB) \
|
||||
static void BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB( \
|
||||
int iters) { \
|
||||
testing::StopTiming(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
|
||||
std::string label; \
|
||||
if (S == 0) { \
|
||||
label = strings::Printf("%d*%d*%d_Eigen", M, K, N); \
|
||||
} else { \
|
||||
label = strings::Printf("%d*%d*%d_sparsity:%0.2f", M, K, N, S / 100.0); \
|
||||
} \
|
||||
std::string label = \
|
||||
strings::Printf("tr_a: %d tr_b: %d sp_a: %0.2f sp_b: %0.2f", TA, TB, \
|
||||
S1 / 100.0, S2 / 100.0); \
|
||||
testing::SetLabel(label); \
|
||||
testing::UseRealTime(); \
|
||||
auto g = SparseMatMul(M, N, K, S / 100.0, false, false); \
|
||||
auto g = SparseMatMul(M, N, K, S1 / 100.0, S2 / 100.0, TA, TB); \
|
||||
testing::StartTiming(); \
|
||||
test::Benchmark("cpu", g).Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S);
|
||||
BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S1##_##S2##_##TA##_##TB);
|
||||
|
||||
BM_SPARSE(2048, 2048, 2048, 0);
|
||||
BM_SPARSE(2048, 2048, 2048, 1);
|
||||
BM_SPARSE(2048, 2048, 2048, 50);
|
||||
BM_SPARSE(2048, 2048, 2048, 85);
|
||||
BM_SPARSE(2048, 2048, 2048, 99);
|
||||
BM_SPARSE(2048, 2048, 2048, 0, 0, false, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 1, 0, false, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 50, 0, false, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 85, 0, false, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 99, 0, false, false);
|
||||
|
||||
BM_SPARSE(1024, 1024, 1024, 0);
|
||||
BM_SPARSE(1024, 1024, 1024, 1);
|
||||
BM_SPARSE(1024, 1024, 1024, 85);
|
||||
BM_SPARSE(2048, 2048, 2048, 0, 50, false, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 0, 85, false, false);
|
||||
|
||||
BM_SPARSE(256, 256, 256, 1);
|
||||
BM_SPARSE(512, 512, 512, 1);
|
||||
BM_SPARSE(2048, 2048, 2048, 85, 0, true, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 85, 0, false, true);
|
||||
BM_SPARSE(2048, 2048, 2048, 85, 0, true, true);
|
||||
|
||||
BM_SPARSE(2048, 2048, 2048, 0, 85, true, false);
|
||||
BM_SPARSE(2048, 2048, 2048, 0, 85, false, true);
|
||||
BM_SPARSE(2048, 2048, 2048, 0, 85, true, true);
|
||||
|
||||
BM_SPARSE(1024, 1024, 1024, 0, 0, false, false);
|
||||
BM_SPARSE(1024, 1024, 1024, 1, 0, false, false);
|
||||
BM_SPARSE(1024, 1024, 1024, 85, 0, false, false);
|
||||
|
||||
BM_SPARSE(256, 256, 256, 1, 0, false, false);
|
||||
BM_SPARSE(512, 512, 512, 1, 0, false, false);
|
||||
|
||||
static Graph* MultiSparseMatMul(int m, int n, int d, float sparsity_1,
|
||||
float sparsity_2) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
SparseMatMulHelper(g, d, n, m, sparsity_1, sparsity_2, true, false);
|
||||
SparseMatMulHelper(g, m, d, n, sparsity_2, 0, false, true);
|
||||
return g;
|
||||
}
|
||||
|
||||
#define BM_SPARSE_MULTI(M, K, N, S1, S2) \
|
||||
static void BM_Sparse_Multi##_##M##_##K##_##N##_##S1##_##S2(int iters) { \
|
||||
@ -151,22 +153,4 @@ BM_SPARSE(512, 512, 512, 1);
|
||||
BM_SPARSE_MULTI(1024, 2140, 4096, 0, 82);
|
||||
BM_SPARSE_MULTI(1024, 4096, 2048, 83, 83);
|
||||
|
||||
#define BM_SPARSE_TR(M, K, N, S, TA, TB) \
|
||||
static void BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB(int iters) { \
|
||||
testing::StopTiming(); \
|
||||
testing::ItemsProcessed(static_cast<int64>(iters) * M * K * N * 2); \
|
||||
std::string label = \
|
||||
strings::Printf("%d_%d_%d_%d_%d_%0.2f", M, K, N, TA, TB, S / 100.0); \
|
||||
testing::SetLabel(label); \
|
||||
testing::UseRealTime(); \
|
||||
auto g = SparseMatMul(M, N, K, S / 100.0, TA, TB); \
|
||||
testing::StartTiming(); \
|
||||
test::Benchmark("cpu", g).Run(iters); \
|
||||
} \
|
||||
BENCHMARK(BM_Sparse##_##M##_##K##_##N##_##S##_##TA##_##TB);
|
||||
|
||||
BM_SPARSE_TR(2048, 2048, 2048, 1, true, false);
|
||||
BM_SPARSE_TR(2048, 2048, 2048, 1, false, true);
|
||||
BM_SPARSE_TR(2048, 2048, 2048, 1, true, true);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
181
tensorflow/core/kernels/stack_ops.cc
Normal file
@ -0,0 +1,181 @@
|
||||
// See docs in ../ops/data_flow_ops.cc.
|
||||
|
||||
#include <limits.h>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/tensor.h"
|
||||
#include "tensorflow/core/public/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Stack : public ResourceBase {
|
||||
public:
|
||||
Stack(const DataType& elem_type, const Tensor& handle)
|
||||
: elem_type_(elem_type), handle_(handle) {}
|
||||
|
||||
void Push(const PersistentTensor& value) {
|
||||
mutex_lock l(mu_);
|
||||
stack_.push_back(value);
|
||||
}
|
||||
|
||||
bool Pop(PersistentTensor* value) {
|
||||
mutex_lock l(mu_);
|
||||
if (!stack_.empty()) {
|
||||
*value = stack_.back();
|
||||
stack_.pop_back();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
DataType ElemType() { return elem_type_; }
|
||||
|
||||
string DebugString() override {
|
||||
mutex_lock l(mu_);
|
||||
return strings::StrCat("#elem:", stack_.size());
|
||||
}
|
||||
|
||||
private:
|
||||
friend class StackOp;
|
||||
mutex* mu() { return &mu_; }
|
||||
Tensor* handle() { return &handle_; }
|
||||
|
||||
mutex mu_;
|
||||
DataType elem_type_;
|
||||
Tensor handle_;
|
||||
std::vector<PersistentTensor> stack_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
// A per-run local stack. The stack uses a "per-step" resource manager which
|
||||
// ensures that correct garbage collection on error or successful completion.
|
||||
class StackOp : public OpKernel {
|
||||
public:
|
||||
explicit StackOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("elem_type", &elem_type_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("stack_name", &stack_name_));
|
||||
if (stack_name_ == "") stack_name_ = name();
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// Create the stack handle.
|
||||
Tensor stack_handle;
|
||||
AllocatorAttributes alloc_attr;
|
||||
alloc_attr.set_on_host(true);
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
|
||||
tensorflow::TensorShape({2}),
|
||||
&stack_handle, alloc_attr));
|
||||
auto handle = stack_handle.flat<string>();
|
||||
handle(0) = "_stacks";
|
||||
handle(1) = stack_name_;
|
||||
// Store the handle in a container of the per-step RM.
|
||||
ResourceMgr* rm = ctx->step_resource_manager();
|
||||
OP_REQUIRES(ctx, rm != nullptr,
|
||||
errors::Internal("No per-step resource manager."));
|
||||
Stack* stack = new Stack(elem_type_, stack_handle);
|
||||
OP_REQUIRES_OK(ctx, rm->Create(handle(0), stack_name_, stack));
|
||||
ctx->set_output_ref(0, stack->mu(), stack->handle());
|
||||
}
|
||||
|
||||
private:
|
||||
DataType elem_type_;
|
||||
string stack_name_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(StackOp);
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"),
|
||||
StackOp);
|
||||
|
||||
class StackPushOp : public OpKernel {
|
||||
public:
|
||||
explicit StackPushOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor Tstack_handle = ctx->mutable_input(0, false);
|
||||
OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2,
|
||||
errors::InvalidArgument(
|
||||
"Stack handle must have two elements, but had shape: ",
|
||||
Tstack_handle.shape().DebugString()));
|
||||
const string& container = Tstack_handle.flat<string>()(0);
|
||||
const string& stack_name = Tstack_handle.flat<string>()(1);
|
||||
ResourceMgr* rm = ctx->step_resource_manager();
|
||||
OP_REQUIRES(ctx, rm != nullptr,
|
||||
errors::Internal("No per-step resource manager."));
|
||||
Stack* stack = nullptr;
|
||||
OP_REQUIRES_OK(ctx, rm->Lookup(container, stack_name, &stack));
|
||||
OP_REQUIRES(ctx, ctx->input_dtype(1) == stack->ElemType(),
|
||||
errors::InvalidArgument("Must have type ", stack->ElemType(),
|
||||
" but got ", ctx->input_dtype(1)));
|
||||
stack->Push(PersistentTensor(ctx->input(1)));
|
||||
ctx->set_output(0, ctx->input(1));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU), StackPushOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("StackPush").Device(DEVICE_GPU).HostMemory("handle"), StackPushOp);
|
||||
|
||||
class StackPopOp : public OpKernel {
|
||||
public:
|
||||
explicit StackPopOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor Tstack_handle = ctx->mutable_input(0, false);
|
||||
OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2,
|
||||
errors::InvalidArgument(
|
||||
"Stack handle must have two elements, but had shape: ",
|
||||
Tstack_handle.shape().DebugString()));
|
||||
const string& container = Tstack_handle.flat<string>()(0);
|
||||
const string& stack_name = Tstack_handle.flat<string>()(1);
|
||||
ResourceMgr* rm = ctx->step_resource_manager();
|
||||
OP_REQUIRES(ctx, rm != nullptr,
|
||||
errors::Internal("No per-step resource manager."));
|
||||
Stack* stack = nullptr;
|
||||
OP_REQUIRES_OK(ctx, rm->Lookup(container, stack_name, &stack));
|
||||
PersistentTensor value;
|
||||
bool has_value = stack->Pop(&value);
|
||||
if (!has_value) {
|
||||
errors::InvalidArgument("Calling Pop() when the stack is empty.");
|
||||
}
|
||||
ctx->set_output(0, *value.AccessTensor(ctx));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StackPop").Device(DEVICE_CPU), StackPopOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("StackPop").Device(DEVICE_GPU).HostMemory("handle"), StackPopOp);
|
||||
|
||||
class StackCloseOp : public OpKernel {
|
||||
public:
|
||||
explicit StackCloseOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
Tensor Tstack_handle = ctx->mutable_input(0, false);
|
||||
OP_REQUIRES(ctx, Tstack_handle.NumElements() == 2,
|
||||
errors::InvalidArgument(
|
||||
"Stack handle must have two elements, but had shape: ",
|
||||
Tstack_handle.shape().DebugString()));
|
||||
const string& container = Tstack_handle.flat<string>()(0);
|
||||
const string& stack_name = Tstack_handle.flat<string>()(1);
|
||||
ResourceMgr* rm = ctx->step_resource_manager();
|
||||
OP_REQUIRES(ctx, rm != nullptr,
|
||||
errors::Internal("No per-step resource manager."));
|
||||
OP_REQUIRES_OK(ctx, rm->Delete<Stack>(container, stack_name));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("StackClose").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp);
|
||||
|
||||
} // namespace tensorflow
|
@ -40,7 +40,8 @@ class StringToHashBucketOp : public OpKernel {
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<int64>();
|
||||
|
||||
for (size_t i = 0; i < input_flat.size(); ++i) {
|
||||
typedef decltype(input_flat.size()) Index;
|
||||
for (Index i = 0; i < input_flat.size(); ++i) {
|
||||
const uint64 input_hash = Hash64(input_flat(i));
|
||||
const uint64 bucket_id = input_hash % num_buckets_;
|
||||
// The number of buckets is always in the positive range of int64 so is
|
||||
|
@ -1,7 +1,18 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
//
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Decodes the blocks generated by block_builder.cc.
|
||||
|
||||
#include "tensorflow/core/lib/io/block.h"
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LIB_IO_BLOCK_H_
|
||||
#define TENSORFLOW_LIB_IO_BLOCK_H_
|
||||
|
@ -1,7 +1,18 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
//
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// BlockBuilder generates blocks where keys are prefix-compressed:
|
||||
//
|
||||
// When we store a key, we drop the prefix shared with the previous
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
|
||||
#define TENSORFLOW_LIB_IO_BLOCK_BUILDER_H_
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/format.h"
|
||||
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LIB_IO_FORMAT_H_
|
||||
#define TENSORFLOW_LIB_IO_FORMAT_H_
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/iterator.h"
|
||||
|
||||
|
@ -1,7 +1,18 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
//
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// An iterator yields a sequence of key/value pairs from a source.
|
||||
// The following class defines the interface. Multiple implementations
|
||||
// are provided by this library. In particular, iterators are provided
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/table.h"
|
||||
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LIB_IO_TABLE_H_
|
||||
#define TENSORFLOW_LIB_IO_TABLE_H_
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/table_builder.h"
|
||||
|
||||
|
@ -1,7 +1,18 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
//
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// TableBuilder provides the interface used to build a Table
|
||||
// (an immutable and sorted map from keys to values).
|
||||
//
|
||||
|
@ -1,8 +1,8 @@
|
||||
File format
|
||||
===========
|
||||
|
||||
The table format is heavily based on the table format for the LevelDB
|
||||
The table format is similar to the table format for the LevelDB
|
||||
open source key/value store, with the exception that our tables
|
||||
do not support "filter" meta blocks (Bloom Filters). See:
|
||||
|
||||
https://code.google.com/p/leveldb/source/browse/doc/table_format.txt
|
||||
https://github.com/google/leveldb/blob/master/doc/table_format.txt
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/table.h"
|
||||
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/io/two_level_iterator.h"
|
||||
|
||||
|
@ -1,6 +1,17 @@
|
||||
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file. See the AUTHORS file for names of contributors.
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
|
||||
#define TENSORFLOW_LIB_IO_TWO_LEVEL_ITERATOR_H_
|
||||
|
@ -301,6 +301,55 @@ size: The number of elements in the given queue.
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("Stack")
|
||||
.Output("handle: Ref(string)")
|
||||
.Attr("elem_type: type")
|
||||
.Attr("stack_name: string = ''")
|
||||
.SetIsStateful()
|
||||
.Doc(R"doc(
|
||||
A stack that produces elements in first-in last-out order.
|
||||
|
||||
handle: The handle to the stack.
|
||||
elem_type: The type of the elements on the stack.
|
||||
stack_name: Overrides the name used for the temporary stack resource. Default
|
||||
value is the name of the 'Stack' op (which is guaranteed unique).
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("StackPush")
|
||||
.Input("handle: Ref(string)")
|
||||
.Input("elem: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Doc(R"doc(
|
||||
Push an element onto the stack.
|
||||
|
||||
handle: The handle to a stack.
|
||||
elem: The tensor to be pushed onto the stack.
|
||||
output: The same tensor as the input 'elem'.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("StackPop")
|
||||
.Input("handle: Ref(string)")
|
||||
.Output("elem: elem_type")
|
||||
.Attr("elem_type: type")
|
||||
.Doc(R"doc(
|
||||
Pop the element at the top of the stack.
|
||||
|
||||
handle: The handle to a stack.
|
||||
elem_type: The type of the elem that is popped.
|
||||
elem: The tensor that is popped from the top of the stack.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("StackClose")
|
||||
.Input("handle: Ref(string)")
|
||||
.Doc(R"doc(
|
||||
Delete the stack from its resource container.
|
||||
|
||||
handle: The handle to a stack.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("LookupTableFind")
|
||||
.Input("table_handle: Ref(string)")
|
||||
.Input("keys: Tin")
|
||||
|
@ -89,6 +89,22 @@ resized_images: 4-D with shape
|
||||
`[batch, new_height, new_width, channels]`.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("ResizeNearestNeighborGrad")
|
||||
.Input("grads: T")
|
||||
.Input("size: int32")
|
||||
.Output("output: T")
|
||||
.Attr("T: {uint8, int8, int32, float, double}")
|
||||
.Doc(R"doc(
|
||||
Computes the gradient of nearest neighbor interpolation.
|
||||
|
||||
grads: 4-D with shape `[batch, height, width, channels]`.
|
||||
size:= A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The
|
||||
original input size.
|
||||
output: 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients
|
||||
with respect to the input image.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("RandomCrop")
|
||||
.Input("image: T")
|
||||
|
@ -655,7 +655,7 @@ that `segment_ids[j] == i`.
|
||||
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||
first dimension. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
)doc");
|
||||
|
||||
@ -684,7 +684,7 @@ values summed.
|
||||
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||
first dimension. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
)doc");
|
||||
|
||||
@ -712,7 +712,7 @@ that `segment_ids[j] == i`.
|
||||
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||
first dimension. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
)doc");
|
||||
|
||||
@ -740,7 +740,7 @@ that `segment_ids[j] == i`.
|
||||
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||
first dimension. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
)doc");
|
||||
|
||||
@ -767,7 +767,7 @@ that `segment_ids[j] == i`.
|
||||
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||
first dimension. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
)doc");
|
||||
|
||||
@ -802,7 +802,7 @@ If the sum is empty for a given segment ID `i`, `output[i] = 0`.
|
||||
segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||
first dimension.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `num_segments`.
|
||||
|
||||
)doc");
|
||||
@ -821,7 +821,7 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
|
||||
of segments.
|
||||
|
||||
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
|
||||
dimension, selecting a subset of dimension_0, specified by `indices`.
|
||||
dimension, selecting a subset of dimension 0, specified by `indices`.
|
||||
|
||||
For example:
|
||||
|
||||
@ -850,7 +850,7 @@ indices: A 1-D tensor. Has same rank as `segment_ids`.
|
||||
|
||||
segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
)doc");
|
||||
|
||||
@ -868,13 +868,13 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
|
||||
of segments.
|
||||
|
||||
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
|
||||
dimension, selecting a subset of dimension_0, specified by `indices`.
|
||||
dimension, selecting a subset of dimension 0, specified by `indices`.
|
||||
|
||||
indices: A 1-D tensor. Has same rank as `segment_ids`.
|
||||
|
||||
segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
|
||||
|
||||
output: Has same shape as data, except for dimension_0 which
|
||||
output: Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
)doc");
|
||||
@ -889,13 +889,13 @@ REGISTER_OP("SparseSegmentMeanGrad")
|
||||
.Doc(R"doc(
|
||||
Computes gradients for SparseSegmentMean.
|
||||
|
||||
Returns tensor "output" with same shape as grad, except for dimension_0 whose
|
||||
Returns tensor "output" with same shape as grad, except for dimension 0 whose
|
||||
value is output_dim0.
|
||||
|
||||
grad: gradient propagated to the SparseSegmentMean op.
|
||||
indices: indices passed to the corresponding SparseSegmentMean op.
|
||||
segment_ids: segment_ids passed to the corresponding SparseSegmentMean op.
|
||||
output_dim0: dimension_0 of "data" passed to SparseSegmentMean op.
|
||||
output_dim0: dimension 0 of "data" passed to SparseSegmentMean op.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("All")
|
||||
|
@ -1874,7 +1874,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "A Tensor with one more dimension than the input bytes. The\nadded dimension will have size equal to the length of the elements\nof bytes divided by the number of bytes to represent out_type."
|
||||
description: "A Tensor with one more dimension than the input `bytes`. The\nadded dimension will have size equal to the length of the elements\nof `bytes` divided by the number of bytes to represent `out_type`."
|
||||
type_attr: "out_type"
|
||||
}
|
||||
attr {
|
||||
@ -1898,7 +1898,7 @@ op {
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
description: "Whether the input bytes are in little-endian order.\nIgnored for out_types that are stored in a single byte like uint8."
|
||||
description: "Whether the input `bytes` are in little-endian order.\nIgnored for `out_type` values that are stored in a single byte like\n`uint8`."
|
||||
}
|
||||
summary: "Reinterpret the bytes of a string as a vector of numbers."
|
||||
}
|
||||
@ -5666,6 +5666,38 @@ op {
|
||||
summary: "Resize `images` to `size` using nearest neighbor interpolation."
|
||||
description: "Input images can be of different types but output images are always float."
|
||||
}
|
||||
op {
|
||||
name: "ResizeNearestNeighborGrad"
|
||||
input_arg {
|
||||
name: "grads"
|
||||
description: "4-D with shape `[batch, height, width, channels]`."
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "size"
|
||||
description: "= A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The\noriginal input size."
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients\nwith respect to the input image."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_INT8
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Computes the gradient of nearest neighbor interpolation."
|
||||
}
|
||||
op {
|
||||
name: "Restore"
|
||||
input_arg {
|
||||
@ -6125,7 +6157,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -6169,7 +6201,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -6213,7 +6245,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -6257,7 +6289,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -6301,7 +6333,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -7044,7 +7076,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -7058,7 +7090,7 @@ op {
|
||||
}
|
||||
}
|
||||
summary: "Computes the mean along sparse segments of a tensor."
|
||||
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentMean`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension_0, specified by `indices`."
|
||||
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentMean`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`."
|
||||
}
|
||||
op {
|
||||
name: "SparseSegmentMeanGrad"
|
||||
@ -7079,7 +7111,7 @@ op {
|
||||
}
|
||||
input_arg {
|
||||
name: "output_dim0"
|
||||
description: "dimension_0 of \"data\" passed to SparseSegmentMean op."
|
||||
description: "dimension 0 of \"data\" passed to SparseSegmentMean op."
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
@ -7097,7 +7129,7 @@ op {
|
||||
}
|
||||
}
|
||||
summary: "Computes gradients for SparseSegmentMean."
|
||||
description: "Returns tensor \"output\" with same shape as grad, except for dimension_0 whose\nvalue is output_dim0."
|
||||
description: "Returns tensor \"output\" with same shape as grad, except for dimension 0 whose\nvalue is output_dim0."
|
||||
}
|
||||
op {
|
||||
name: "SparseSegmentSum"
|
||||
@ -7117,7 +7149,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `k`, the number of segments."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `k`, the number of segments."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
@ -7136,7 +7168,7 @@ op {
|
||||
}
|
||||
}
|
||||
summary: "Computes the sum along sparse segments of a tensor."
|
||||
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension_0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n ==> [[ 1 2 3 4]\n [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n ==> [[0 0 0 0]\n [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```"
|
||||
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nLike `SegmentSum`, but `segment_ids` can have rank less than `data`\'s first\ndimension, selecting a subset of dimension 0, specified by `indices`.\n\nFor example:\n\n```prettyprint\nc = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])\n\n# Select two rows, one segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))\n ==> [[0 0 0 0]]\n\n# Select two rows, two segment.\ntf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))\n ==> [[ 1 2 3 4]\n [-1 -2 -3 -4]]\n\n# Select all rows, two segments.\ntf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))\n ==> [[0 0 0 0]\n [5 6 7 8]]\n\n# Which is equivalent to:\ntf.segment_sum(c, tf.constant([0, 0, 1]))\n```"
|
||||
}
|
||||
op {
|
||||
name: "SparseToDense"
|
||||
@ -7294,6 +7326,84 @@ op {
|
||||
summary: "Removes dimensions of size 1 from the shape of a tensor."
|
||||
description: "Given a tensor `input`, this operation returns a tensor of the same type with\nall dimensions of size 1 removed. If you don\'t want to remove all size 1\ndimensions, you can remove specific size 1 dimensions by specifying\n`squeeze_dims`.\n\nFor example:\n\n```prettyprint\n# \'t\' is a tensor of shape [1, 2, 1, 3, 1, 1]\nshape(squeeze(t)) ==> [2, 3]\n```\n\nOr, to remove specific size 1 dimensions:\n\n```prettyprint\n# \'t\' is a tensor of shape [1, 2, 1, 3, 1, 1]\nshape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]\n```"
|
||||
}
|
||||
op {
|
||||
name: "Stack"
|
||||
output_arg {
|
||||
name: "handle"
|
||||
description: "The handle to the stack."
|
||||
type: DT_STRING
|
||||
is_ref: true
|
||||
}
|
||||
attr {
|
||||
name: "elem_type"
|
||||
type: "type"
|
||||
description: "The type of the elements on the stack."
|
||||
}
|
||||
attr {
|
||||
name: "stack_name"
|
||||
type: "string"
|
||||
default_value {
|
||||
s: ""
|
||||
}
|
||||
description: "Overrides the name used for the temporary stack resource. Default\nvalue is the name of the \'Stack\' op (which is guaranteed unique)."
|
||||
}
|
||||
summary: "A stack that produces elements in first-in last-out order."
|
||||
is_stateful: true
|
||||
}
|
||||
op {
|
||||
name: "StackClose"
|
||||
input_arg {
|
||||
name: "handle"
|
||||
description: "The handle to a stack."
|
||||
type: DT_STRING
|
||||
is_ref: true
|
||||
}
|
||||
summary: "Delete the stack from its resource container."
|
||||
}
|
||||
op {
|
||||
name: "StackPop"
|
||||
input_arg {
|
||||
name: "handle"
|
||||
description: "The handle to a stack."
|
||||
type: DT_STRING
|
||||
is_ref: true
|
||||
}
|
||||
output_arg {
|
||||
name: "elem"
|
||||
description: "The tensor that is popped from the top of the stack."
|
||||
type_attr: "elem_type"
|
||||
}
|
||||
attr {
|
||||
name: "elem_type"
|
||||
type: "type"
|
||||
description: "The type of the elem that is popped."
|
||||
}
|
||||
summary: "Pop the element at the top of the stack."
|
||||
}
|
||||
op {
|
||||
name: "StackPush"
|
||||
input_arg {
|
||||
name: "handle"
|
||||
description: "The handle to a stack."
|
||||
type: DT_STRING
|
||||
is_ref: true
|
||||
}
|
||||
input_arg {
|
||||
name: "elem"
|
||||
description: "The tensor to be pushed onto the stack."
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "The same tensor as the input \'elem\'."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
}
|
||||
summary: "Push an element onto the stack."
|
||||
}
|
||||
op {
|
||||
name: "StopGradient"
|
||||
input_arg {
|
||||
@ -7319,7 +7429,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "A Tensor of the same shape as the input string_tensor."
|
||||
description: "A Tensor of the same shape as the input `string_tensor`."
|
||||
type: DT_INT64
|
||||
}
|
||||
attr {
|
||||
@ -7340,7 +7450,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "A Tensor of the same shape as the input string_tensor."
|
||||
description: "A Tensor of the same shape as the input `string_tensor`."
|
||||
type_attr: "out_type"
|
||||
}
|
||||
attr {
|
||||
@ -7942,7 +8052,7 @@ op {
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for dimension_0 which\nhas size `num_segments`."
|
||||
description: "Has same shape as data, except for dimension 0 which\nhas size `num_segments`."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
|
@ -26,11 +26,12 @@ REGISTER_OP("DecodeRaw")
|
||||
Reinterpret the bytes of a string as a vector of numbers.
|
||||
|
||||
bytes: All the elements must have the same length.
|
||||
little_endian: Whether the input bytes are in little-endian order.
|
||||
Ignored for out_types that are stored in a single byte like uint8.
|
||||
output: A Tensor with one more dimension than the input bytes. The
|
||||
little_endian: Whether the input `bytes` are in little-endian order.
|
||||
Ignored for `out_type` values that are stored in a single byte like
|
||||
`uint8`.
|
||||
output: A Tensor with one more dimension than the input `bytes`. The
|
||||
added dimension will have size equal to the length of the elements
|
||||
of bytes divided by the number of bytes to represent out_type.
|
||||
of `bytes` divided by the number of bytes to represent `out_type`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("ParseExample")
|
||||
@ -113,7 +114,7 @@ Converts each string in the input Tensor to the specified numeric type.
|
||||
results in a rounded value.)
|
||||
|
||||
out_type: The numeric type to interpret each string in string_tensor as.
|
||||
output: A Tensor of the same shape as the input string_tensor.
|
||||
output: A Tensor of the same shape as the input `string_tensor`.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -30,7 +30,7 @@ process.
|
||||
Note that the hash function may change from time to time.
|
||||
|
||||
num_buckets: The number of buckets.
|
||||
output: A Tensor of the same shape as the input string_tensor.
|
||||
output: A Tensor of the same shape as the input `string_tensor`.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -38,6 +38,7 @@ def tf_proto_library(name, srcs = [], has_services = False,
|
||||
|
||||
py_proto_library(name=name + "_py",
|
||||
srcs=srcs + tf_deps(deps, "_proto_srcs"),
|
||||
srcs_version="PY2AND3",
|
||||
deps=deps,
|
||||
py_libs = ["//google/protobuf:protobuf_python"],
|
||||
testonly=testonly,
|
||||
@ -46,6 +47,7 @@ def tf_proto_library(name, srcs = [], has_services = False,
|
||||
def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0):
|
||||
py_proto_library(name = name + "_py",
|
||||
srcs = srcs,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = deps,
|
||||
visibility = visibility,
|
||||
testonly = testonly)
|
||||
|
@ -1,33 +1,18 @@
|
||||
// Copyright (c) 2008, Google Inc.
|
||||
// All rights reserved.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
// ---
|
||||
//
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header file contains the macro definitions for thread safety
|
||||
// annotations that allow the developers to document the locking policies
|
||||
// of their multi-threaded code. The annotations can also help program
|
||||
|
@ -107,7 +107,7 @@ class Env {
|
||||
/// Deletes the specified directory.
|
||||
virtual Status DeleteDir(const string& dirname) = 0;
|
||||
|
||||
/// Stores the size of fname in *file_size.
|
||||
/// Stores the size of `fname` in `*file_size`.
|
||||
virtual Status GetFileSize(const string& fname, uint64* file_size) = 0;
|
||||
|
||||
/// \brief Renames file src to target. If target already exists, it will be
|
||||
@ -146,18 +146,18 @@ class RandomAccessFile {
|
||||
RandomAccessFile() {}
|
||||
virtual ~RandomAccessFile();
|
||||
|
||||
/// \brief Reads up to "n" bytes from the file starting at "offset".
|
||||
/// \brief Reads up to `n` bytes from the file starting at `offset`.
|
||||
///
|
||||
/// "scratch[0..n-1]" may be written by this routine. Sets "*result"
|
||||
/// to the data that was read (including if fewer than "n" bytes were
|
||||
/// successfully read). May set "*result" to point at data in
|
||||
/// "scratch[0..n-1]", so "scratch[0..n-1]" must be live when
|
||||
/// "*result" is used.
|
||||
/// `scratch[0..n-1]` may be written by this routine. Sets `*result`
|
||||
/// to the data that was read (including if fewer than `n` bytes were
|
||||
/// successfully read). May set `*result` to point at data in
|
||||
/// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when
|
||||
/// `*result` is used.
|
||||
///
|
||||
/// On OK returned status: "n" bytes have been stored in "*result".
|
||||
/// On non-OK returned status: [0..n] bytes have been stored in "*result".
|
||||
/// On OK returned status: `n` bytes have been stored in `*result`.
|
||||
/// On non-OK returned status: `[0..n]` bytes have been stored in `*result`.
|
||||
///
|
||||
/// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in "*result"
|
||||
/// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result`
|
||||
/// because of EOF.
|
||||
///
|
||||
/// Safe for concurrent use by multiple threads.
|
||||
@ -263,16 +263,16 @@ struct ThreadOptions {
|
||||
size_t guard_size = 0; // 0: use system default value
|
||||
};
|
||||
|
||||
/// A utility routine: reads contents of named file into *data
|
||||
/// A utility routine: reads contents of named file into `*data`
|
||||
Status ReadFileToString(Env* env, const string& fname, string* data);
|
||||
|
||||
/// A utility routine: write contents of "data" to file named "fname"
|
||||
/// A utility routine: write contents of `data` to file named `fname`
|
||||
/// (overwriting existing contents, if any).
|
||||
Status WriteStringToFile(Env* env, const string& fname,
|
||||
const StringPiece& data);
|
||||
|
||||
/// Reads contents of named file and parse as binary encoded proto data
|
||||
/// and store into *proto.
|
||||
/// and store into `*proto`.
|
||||
Status ReadBinaryProto(Env* env, const string& fname,
|
||||
::tensorflow::protobuf::MessageLite* proto);
|
||||
|
||||
|
@ -60,6 +60,15 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
|
||||
return config;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __host__ inline T ldg(const T* address) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
|
||||
return __ldg(address);
|
||||
#else
|
||||
return *address;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@ -90,7 +90,7 @@ string EventsWriter::FileName() {
|
||||
return filename_;
|
||||
}
|
||||
|
||||
void EventsWriter::WriteSerializedEvent(const string& event_str) {
|
||||
void EventsWriter::WriteSerializedEvent(StringPiece event_str) {
|
||||
if (recordio_writer_.get() == NULL) {
|
||||
if (!Init()) {
|
||||
LOG(ERROR) << "Write failed because file could not be opened.";
|
||||
|
@ -61,8 +61,8 @@ class EventsWriter {
|
||||
|
||||
// Append "event_str", a serialized Event, to the file.
|
||||
// Note that this function does NOT check that de-serializing event_str
|
||||
// results in a valid Event proto.
|
||||
void WriteSerializedEvent(const string& event_str);
|
||||
// results in a valid Event proto. The tensorflow:: bit makes SWIG happy.
|
||||
void WriteSerializedEvent(tensorflow::StringPiece event_str);
|
||||
|
||||
// EventWriter automatically flushes and closes on destruction, but
|
||||
// these two methods are provided for users who want to write to disk sooner
|
||||
|
@ -14,6 +14,7 @@ cc_library(
|
||||
copts = [
|
||||
"-std=c++11",
|
||||
"-mfpu=neon",
|
||||
"-O2",
|
||||
],
|
||||
linkopts = ["-llog -landroid -lm -ljnigraphics"],
|
||||
tags = [
|
||||
|
@ -36,13 +36,9 @@ Then, after editing your WORKSPACE file, you must build the APK. Run this from
|
||||
your workspace root:
|
||||
|
||||
```bash
|
||||
$ bazel build //tensorflow/examples/android:tensorflow_demo -c opt --copt=-mfpu=neon
|
||||
$ bazel build //tensorflow/examples/android:tensorflow_demo
|
||||
```
|
||||
|
||||
Note that `-c opt` is currently required; if not set, an assert (for an
|
||||
otherwise non-problematic issue) in Eigen will halt the application during
|
||||
execution. This issue will be corrected in an upcoming release.
|
||||
|
||||
If adb debugging is enabled on your Android 5.0 or later device, you may then
|
||||
use the following command from your workspace root to install the APK once
|
||||
built:
|
||||
@ -55,7 +51,7 @@ Alternatively, a streamlined means of building, installing and running in one
|
||||
command is:
|
||||
|
||||
```bash
|
||||
$ bazel mobile-install //tensorflow/examples/android:tensorflow_demo -c opt --start_app --copt=-mfpu=neon
|
||||
$ bazel mobile-install //tensorflow/examples/android:tensorflow_demo --start_app
|
||||
```
|
||||
|
||||
If camera permission errors are encountered (possible on Android Marshmallow or
|
||||
|
@ -231,7 +231,7 @@ public class CameraConnectionFragment extends Fragment {
|
||||
private static Size chooseOptimalSize(
|
||||
final Size[] choices, final int width, final int height, final Size aspectRatio) {
|
||||
// Collect the supported resolutions that are at least as big as the preview Surface
|
||||
final List<Size> bigEnough = new ArrayList<>();
|
||||
final List<Size> bigEnough = new ArrayList<Size>();
|
||||
for (final Size option : choices) {
|
||||
if (option.getHeight() >= MINIMUM_PREVIEW_SIZE && option.getWidth() >= MINIMUM_PREVIEW_SIZE) {
|
||||
LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight());
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo;
|
||||
|
||||
import android.content.Context;
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
@ -9,7 +24,6 @@ import android.media.Image;
|
||||
import android.media.Image.Plane;
|
||||
import android.media.ImageReader;
|
||||
import android.media.ImageReader.OnImageAvailableListener;
|
||||
|
||||
import junit.framework.Assert;
|
||||
|
||||
import org.tensorflow.demo.env.ImageUtils;
|
||||
@ -54,6 +68,44 @@ public class TensorflowImageListener implements OnImageAvailableListener {
|
||||
this.scoreView = scoreView;
|
||||
}
|
||||
|
||||
private void readPlanesToYuvBuffer(final Plane[] planes, final byte[] yuvBytes) {
|
||||
int position = 0;
|
||||
|
||||
// Copy the bytes from the Image into a buffer for easier conversion to RGB.
|
||||
// TODO(andrewharp): Modify native code to accept multiple buffers so that
|
||||
// only one pass is necessary during conversion to RGB.
|
||||
final Plane yPlane = planes[0];
|
||||
final ByteBuffer yBuffer = yPlane.getBuffer();
|
||||
final int yRowStride = yPlane.getRowStride();
|
||||
|
||||
// Read the y (luminance buffer).
|
||||
for (int row = 0; row < previewHeight; ++row) {
|
||||
yBuffer.position(yRowStride * row);
|
||||
|
||||
// Pixel stride is guaranteed to be 1 so we can
|
||||
// just do a copy operation.
|
||||
yBuffer.get(yuvBytes, position, previewWidth);
|
||||
position += previewWidth;
|
||||
}
|
||||
|
||||
// Interleave the u and v buffers.
|
||||
final ByteBuffer uBuffer = planes[1].getBuffer();
|
||||
final ByteBuffer vBuffer = planes[2].getBuffer();
|
||||
final int uvPixelStride = planes[1].getPixelStride();
|
||||
final int uvWidth = previewWidth / 2;
|
||||
final int uvHeight = previewHeight / 2;
|
||||
Assert.assertEquals(
|
||||
planes[1].getRowStride(), planes[2].getRowStride());
|
||||
for (int y = 0; y < uvHeight; ++y) {
|
||||
int readPos = planes[1].getRowStride() * y;
|
||||
for (int x = 0; x < uvWidth; ++x) {
|
||||
yuvBytes[position++] = vBuffer.get(readPos);
|
||||
yuvBytes[position++] = uBuffer.get(readPos);
|
||||
readPos += uvPixelStride;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void drawResizedBitmap(final Bitmap src, final Bitmap dst) {
|
||||
Assert.assertEquals(dst.getWidth(), dst.getHeight());
|
||||
final float minDim = Math.min(src.getWidth(), src.getHeight());
|
||||
@ -91,31 +143,17 @@ public class TensorflowImageListener implements OnImageAvailableListener {
|
||||
|
||||
// Initialize the storage bitmaps once when the resolution is known.
|
||||
if (previewWidth != image.getWidth() || previewHeight != image.getHeight()) {
|
||||
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
|
||||
previewWidth = image.getWidth();
|
||||
previewHeight = image.getHeight();
|
||||
|
||||
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
|
||||
rgbBytes = new int[previewWidth * previewHeight];
|
||||
yuvBytes = new byte[ImageUtils.getYUVByteSize(previewWidth, previewHeight)];
|
||||
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
|
||||
croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
|
||||
}
|
||||
|
||||
final Plane[] planes = image.getPlanes();
|
||||
int position = 0;
|
||||
|
||||
// Copy the bytes from the Image into a buffer for easier conversion to RGB.
|
||||
// TODO(andrewharp): It may not be correct to do it this way.
|
||||
final int[] planeOrder = {0, 2};
|
||||
for (int i = 0; i < planeOrder.length; ++i) {
|
||||
final Plane plane = planes[planeOrder[i]];
|
||||
final ByteBuffer buffer = plane.getBuffer();
|
||||
|
||||
buffer.rewind();
|
||||
final int readAmount = buffer.remaining();
|
||||
|
||||
buffer.get(yuvBytes, position, readAmount);
|
||||
position += readAmount;
|
||||
}
|
||||
readPlanesToYuvBuffer(image.getPlanes(), yuvBytes);
|
||||
|
||||
image.close();
|
||||
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo.env;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
|
@ -1,3 +1,18 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
package org.tensorflow.demo.env;
|
||||
|
||||
import android.util.Log;
|
||||
|
@ -20,7 +20,7 @@ limitations under the License.
|
||||
// It's designed to have as few dependencies and be as clear as possible, so
|
||||
// it's more verbose than it could be in production code. In particular, using
|
||||
// auto for the types of a lot of the returned values from TensorFlow calls can
|
||||
// remove a lot of boilerplate, but I find them explicit types useful in sample
|
||||
// remove a lot of boilerplate, but I find the explicit types useful in sample
|
||||
// code to make it simple to look up the classes involved.
|
||||
//
|
||||
// To use it, compile and then run in a working directory with the
|
||||
|
@ -27,7 +27,7 @@ All Env implementations are safe for concurrent access from multiple threads wit
|
||||
* [`virtual Status tensorflow::Env::DeleteDir(const string &dirname)=0`](#virtual_Status_tensorflow_Env_DeleteDir)
|
||||
* Deletes the specified directory.
|
||||
* [`virtual Status tensorflow::Env::GetFileSize(const string &fname, uint64 *file_size)=0`](#virtual_Status_tensorflow_Env_GetFileSize)
|
||||
* Stores the size of fname in *file_size.
|
||||
* Stores the size of `fname` in `*file_size`.
|
||||
* [`virtual Status tensorflow::Env::RenameFile(const string &src, const string &target)=0`](#virtual_Status_tensorflow_Env_RenameFile)
|
||||
* Renames file src to target. If target already exists, it will be replaced.
|
||||
* [`virtual uint64 tensorflow::Env::NowMicros()=0`](#virtual_uint64_tensorflow_Env_NowMicros)
|
||||
@ -109,7 +109,7 @@ Deletes the specified directory.
|
||||
|
||||
#### `virtual Status tensorflow::Env::GetFileSize(const string &fname, uint64 *file_size)=0` {#virtual_Status_tensorflow_Env_GetFileSize}
|
||||
|
||||
Stores the size of fname in *file_size.
|
||||
Stores the size of `fname` in `*file_size`.
|
||||
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ May be useful to clients who wish to override just part of the functionality of
|
||||
* [`Status tensorflow::EnvWrapper::DeleteDir(const string &d) override`](#Status_tensorflow_EnvWrapper_DeleteDir)
|
||||
* Deletes the specified directory.
|
||||
* [`Status tensorflow::EnvWrapper::GetFileSize(const string &f, uint64 *s) override`](#Status_tensorflow_EnvWrapper_GetFileSize)
|
||||
* Stores the size of fname in *file_size.
|
||||
* Stores the size of `fname` in `*file_size`.
|
||||
* [`Status tensorflow::EnvWrapper::RenameFile(const string &s, const string &t) override`](#Status_tensorflow_EnvWrapper_RenameFile)
|
||||
* Renames file src to target. If target already exists, it will be replaced.
|
||||
* [`uint64 tensorflow::EnvWrapper::NowMicros() override`](#uint64_tensorflow_EnvWrapper_NowMicros)
|
||||
@ -114,7 +114,7 @@ Deletes the specified directory.
|
||||
|
||||
#### `Status tensorflow::EnvWrapper::GetFileSize(const string &f, uint64 *s) override` {#Status_tensorflow_EnvWrapper_GetFileSize}
|
||||
|
||||
Stores the size of fname in *file_size.
|
||||
Stores the size of `fname` in `*file_size`.
|
||||
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ A file abstraction for randomly reading the contents of a file.
|
||||
* [`tensorflow::RandomAccessFile::RandomAccessFile()`](#tensorflow_RandomAccessFile_RandomAccessFile)
|
||||
* [`virtual tensorflow::RandomAccessFile::~RandomAccessFile()`](#virtual_tensorflow_RandomAccessFile_RandomAccessFile)
|
||||
* [`virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0`](#virtual_Status_tensorflow_RandomAccessFile_Read)
|
||||
* Reads up to "n" bytes from the file starting at "offset".
|
||||
* Reads up to `n` bytes from the file starting at `offset`.
|
||||
|
||||
##Member Details
|
||||
|
||||
@ -27,12 +27,12 @@ A file abstraction for randomly reading the contents of a file.
|
||||
|
||||
#### `virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0` {#virtual_Status_tensorflow_RandomAccessFile_Read}
|
||||
|
||||
Reads up to "n" bytes from the file starting at "offset".
|
||||
Reads up to `n` bytes from the file starting at `offset`.
|
||||
|
||||
"scratch[0..n-1]" may be written by this routine. Sets "*result" to the data that was read (including if fewer than "n" bytes were successfully read). May set "*result" to point at data in "scratch[0..n-1]", so "scratch[0..n-1]" must be live when "*result" is used.
|
||||
`scratch[0..n-1]` may be written by this routine. Sets `*result` to the data that was read (including if fewer than `n` bytes were successfully read). May set `*result` to point at data in `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when `*result` is used.
|
||||
|
||||
On OK returned status: "n" bytes have been stored in "*result". On non-OK returned status: [0..n] bytes have been stored in "*result".
|
||||
On OK returned status: `n` bytes have been stored in `*result`. On non-OK returned status: `[0..n]` bytes have been stored in `*result`.
|
||||
|
||||
Returns `OUT_OF_RANGE` if fewer than n bytes were stored in "*result" because of EOF.
|
||||
Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` because of EOF.
|
||||
|
||||
Safe for concurrent use by multiple threads.
|
||||
|
@ -4,7 +4,7 @@ TensorFlow's public C++ API includes only the API for executing graphs, as of
|
||||
version 0.5. To control the execution of a graph from C++:
|
||||
|
||||
1. Build the computation graph using the [Python API](../python/).
|
||||
1. Use [tf.train.write_graph()](../python/train.md#write_graph) to
|
||||
1. Use [`tf.train.write_graph()`](../python/train.md#write_graph) to
|
||||
write the graph to a file.
|
||||
1. Load the graph using the C++ Session API. For example:
|
||||
|
||||
|
Before ![]() (image error) Size: 46 KiB |
Before ![]() (image error) Size: 50 KiB |
Before ![]() (image error) Size: 40 KiB |
Before ![]() (image error) Size: 52 KiB |
Before ![]() (image error) Size: 51 KiB |
Before ![]() (image error) Size: 49 KiB |
Before ![]() (image error) Size: 50 KiB |
Before ![]() (image error) Size: 51 KiB |
Before ![]() (image error) Size: 45 KiB |
Before ![]() (image error) Size: 48 KiB |
Before ![]() (image error) Size: 43 KiB |
Before ![]() (image error) Size: 36 KiB |
@ -32,7 +32,7 @@ results in a rounded value.)
|
||||
##### Returns:
|
||||
|
||||
A `Tensor` of type `out_type`.
|
||||
A Tensor of the same shape as the input string_tensor.
|
||||
A Tensor of the same shape as the input `string_tensor`.
|
||||
|
||||
|
||||
- - -
|
||||
@ -817,7 +817,7 @@ tf.transpose(x) ==> [[1 4]
|
||||
[3 6]]
|
||||
|
||||
# Equivalently
|
||||
tf.transpose(x perm=[0, 1]) ==> [[1 4]
|
||||
tf.transpose(x, perm=[1, 0]) ==> [[1 4]
|
||||
[2 5]
|
||||
[3 6]]
|
||||
|
||||
|
@ -178,6 +178,7 @@ Calling this method frees all resources associated with the session.
|
||||
The graph that was launched in this session.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Session.as_default()` {#Session.as_default}
|
||||
@ -354,6 +355,7 @@ discover information about the op.
|
||||
|
||||
The `Operation` that failed, or None.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.OpError.node_def` {#OpError.node_def}
|
||||
@ -361,6 +363,7 @@ discover information about the op.
|
||||
The `NodeDef` proto representing the op that failed.
|
||||
|
||||
|
||||
|
||||
#### Other Methods
|
||||
- - -
|
||||
|
||||
@ -383,6 +386,7 @@ Creates a new `OpError` indicating that a particular op failed.
|
||||
|
||||
The integer error code that describes the error.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.OpError.message` {#OpError.message}
|
||||
@ -390,6 +394,7 @@ The integer error code that describes the error.
|
||||
The error message that describes the error.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.errors.CancelledError` {#CancelledError}
|
||||
|
@ -504,7 +504,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
|
||||
### `tf.add_check_numerics_ops()` {#add_check_numerics_ops}
|
||||
|
||||
Connect a check_numerics to every floating point tensor.
|
||||
Connect a `check_numerics` to every floating point tensor.
|
||||
|
||||
`check_numerics` operations themselves are added for each `float` or `double`
|
||||
tensor in the graph. For all ops in the graph, the `check_numerics` op for
|
||||
|
@ -119,7 +119,7 @@ This method is thread-safe.
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`ValueError`</b>: If the graph_def would be too large.
|
||||
* <b>`ValueError`</b>: If the `graph_def` would be too large.
|
||||
|
||||
|
||||
- - -
|
||||
@ -141,6 +141,7 @@ when using a [`QueueRunner`](../../api_docs/python/train.md#QueueRunner).
|
||||
True if this graph has been finalized.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Graph.control_dependencies(control_inputs)` {#Graph.control_dependencies}
|
||||
@ -511,6 +512,7 @@ Returns the default device.
|
||||
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Graph.unique_name(name)` {#Graph.unique_name}
|
||||
@ -545,6 +547,7 @@ TensorBoard.
|
||||
Returns a version number that increases as ops are added to the graph.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Graph.create_op(op_type, inputs, dtypes, input_types=None, name=None, attrs=None, op_def=None, compute_shapes=True)` {#Graph.create_op}
|
||||
@ -658,18 +661,21 @@ be executed by passing it to
|
||||
|
||||
The full name of this operation.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.type` {#Operation.type}
|
||||
|
||||
The type of the op (e.g. `"MatMul"`).
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.inputs` {#Operation.inputs}
|
||||
|
||||
The list of `Tensor` objects representing the data inputs of this op.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.control_inputs` {#Operation.control_inputs}
|
||||
@ -686,12 +692,14 @@ in the correct order.
|
||||
|
||||
A list of `Operation` objects.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.outputs` {#Operation.outputs}
|
||||
|
||||
The list of `Tensor` objects representing the outputs of this op.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.device` {#Operation.device}
|
||||
@ -703,6 +711,7 @@ The name of the device to which this op has been assigned, if any.
|
||||
The string name of the device to which this op has been
|
||||
assigned, or None if it has not been assigned to a device.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.graph` {#Operation.graph}
|
||||
@ -710,6 +719,7 @@ The name of the device to which this op has been assigned, if any.
|
||||
The `Graph` that contains this operation.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.run(feed_dict=None, session=None)` {#Operation.run}
|
||||
@ -762,6 +772,7 @@ Returns the value of the attr of this op with the given `name`.
|
||||
Returns the call stack from when this operation was constructed.
|
||||
|
||||
|
||||
|
||||
#### Other Methods
|
||||
- - -
|
||||
|
||||
@ -803,11 +814,11 @@ regular expression:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if control inputs are not Operations or Tensors,
|
||||
or if node_def is not a `NodeDef`,
|
||||
or if g is not a `Graph`,
|
||||
or if inputs are not tensors,
|
||||
or if inputs and input_types are incompatible.
|
||||
* <b>`ValueError`</b>: if the node_def name is not valid.
|
||||
or if `node_def` is not a `NodeDef`,
|
||||
or if `g` is not a `Graph`,
|
||||
or if `inputs` are not tensors,
|
||||
or if `inputs` and `input_types` are incompatible.
|
||||
* <b>`ValueError`</b>: if the `node_def` name is not valid.
|
||||
|
||||
|
||||
- - -
|
||||
@ -822,6 +833,7 @@ Returns a serialized `NodeDef` representation of this operation.
|
||||
[`NodeDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
|
||||
protocol buffer.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.op_def` {#Operation.op_def}
|
||||
@ -834,6 +846,7 @@ Returns the `OpDef` proto that represents the type of this op.
|
||||
[`OpDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_def.proto)
|
||||
protocol buffer.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Operation.values()` {#Operation.values}
|
||||
@ -889,30 +902,35 @@ result = sess.run(e)
|
||||
|
||||
The `DType` of elements in this tensor.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Tensor.name` {#Tensor.name}
|
||||
|
||||
The string name of this tensor.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Tensor.value_index` {#Tensor.value_index}
|
||||
|
||||
The index of this tensor in the outputs of its `Operation`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Tensor.graph` {#Tensor.graph}
|
||||
|
||||
The `Graph` that contains this tensor.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Tensor.op` {#Tensor.op}
|
||||
|
||||
The `Operation` that produces this tensor as an output.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Tensor.consumers()` {#Tensor.consumers}
|
||||
@ -1070,6 +1088,7 @@ The name of the device on which this tensor will be produced, or None.
|
||||
|
||||
|
||||
|
||||
|
||||
## Tensor types
|
||||
|
||||
- - -
|
||||
@ -1136,30 +1155,35 @@ DType(T).as_ref.is_compatible_with(DType(T).as_ref) == True
|
||||
|
||||
Returns the string name for this `DType`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.base_dtype` {#DType.base_dtype}
|
||||
|
||||
Returns a non-reference `DType` based on this `DType`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.is_ref_dtype` {#DType.is_ref_dtype}
|
||||
|
||||
Returns `True` if this `DType` represents a reference type.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.as_ref` {#DType.as_ref}
|
||||
|
||||
Returns a reference `DType` based on this `DType`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.is_integer` {#DType.is_integer}
|
||||
|
||||
Returns whether this is a (non-quantized) integer type.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.is_quantized` {#DType.is_quantized}
|
||||
@ -1167,12 +1191,14 @@ Returns whether this is a (non-quantized) integer type.
|
||||
Returns whether this is a quantized data type.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.as_numpy_dtype` {#DType.as_numpy_dtype}
|
||||
|
||||
Returns a `numpy.dtype` based on this `DType`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.as_datatype_enum` {#DType.as_datatype_enum}
|
||||
@ -1180,6 +1206,7 @@ Returns a `numpy.dtype` based on this `DType`.
|
||||
Returns a `types_pb2.DataType` enum value based on this `DType`.
|
||||
|
||||
|
||||
|
||||
#### Other Methods
|
||||
- - -
|
||||
|
||||
@ -1188,8 +1215,8 @@ Returns a `types_pb2.DataType` enum value based on this `DType`.
|
||||
Creates a new `DataType`.
|
||||
|
||||
NOTE(mrry): In normal circumstances, you should not need to
|
||||
construct a DataType object directly. Instead, use the
|
||||
types.as_dtype() function.
|
||||
construct a `DataType` object directly. Instead, use the
|
||||
`tf.as_dtype()` function.
|
||||
|
||||
##### Args:
|
||||
|
||||
@ -1208,6 +1235,7 @@ types.as_dtype() function.
|
||||
|
||||
Returns whether this is a (real) floating point type.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.max` {#DType.max}
|
||||
@ -1219,6 +1247,7 @@ Returns the maximum representable value in this data type.
|
||||
|
||||
* <b>`TypeError`</b>: if this is a non-numeric, unordered, or quantized type.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.DType.min` {#DType.min}
|
||||
@ -1231,6 +1260,7 @@ Returns the minimum representable value in this data type.
|
||||
* <b>`TypeError`</b>: if this is a non-numeric, unordered, or quantized type.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.as_dtype(type_value)` {#as_dtype}
|
||||
@ -1425,13 +1455,13 @@ protocol buffer, and extract individual objects in the `GraphDef` as
|
||||
##### Returns:
|
||||
|
||||
A list of `Operation` and/or `Tensor` objects from the imported graph,
|
||||
corresponding to the names in `return_elements'.
|
||||
corresponding to the names in `return_elements`.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: If `graph_def` is not a `GraphDef` proto,
|
||||
`input_map' is not a dictionary mapping strings to `Tensor` objects,
|
||||
`input_map` is not a dictionary mapping strings to `Tensor` objects,
|
||||
or `return_elements` is not a list of strings.
|
||||
* <b>`ValueError`</b>: If `input_map`, or `return_elements` contains names that
|
||||
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
|
||||
@ -1613,7 +1643,7 @@ that defines the operation.
|
||||
|
||||
#### `tf.RegisterShape.__init__(op_type)` {#RegisterShape.__init__}
|
||||
|
||||
Saves the "op_type" as the Operation type.
|
||||
Saves the `op_type` as the `Operation` type.
|
||||
|
||||
|
||||
|
||||
@ -1694,12 +1724,14 @@ information for use with slicing.
|
||||
|
||||
Returns the rank of this shape, or None if it is unspecified.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.TensorShape.dims` {#TensorShape.dims}
|
||||
|
||||
Returns a list of Dimensions, or None if the shape is unspecified.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.TensorShape.as_list()` {#TensorShape.as_list}
|
||||
@ -1916,7 +1948,7 @@ Creates a new TensorShape with the given dimensions.
|
||||
|
||||
#### `tf.TensorShape.as_dimension_list()` {#TensorShape.as_dimension_list}
|
||||
|
||||
DEPRECATED: use as_list().
|
||||
DEPRECATED: use `as_list()`.
|
||||
|
||||
|
||||
- - -
|
||||
@ -2014,6 +2046,7 @@ Dimensions are combined as follows:
|
||||
The value of this dimension, or None if it is unknown.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.op_scope(values, name, default_name)` {#op_scope}
|
||||
|
@ -425,7 +425,7 @@ Crops an image to a specified bounding box.
|
||||
This op cuts a rectangular part out of `image`. The top-left corner of the
|
||||
returned image is at `offset_height, offset_width` in `image`, and its
|
||||
lower-right corner is at
|
||||
`offset_height + target_height, offset_width + target_width'.
|
||||
`offset_height + target_height, offset_width + target_width`.
|
||||
|
||||
##### Args:
|
||||
|
||||
@ -724,7 +724,7 @@ have modifications in the range `[-max_delta,max_delta]`.
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`ValueError`</b>: if max_delta is negative.
|
||||
* <b>`ValueError`</b>: if `max_delta` is negative.
|
||||
|
||||
|
||||
|
||||
@ -831,3 +831,27 @@ Note that this implementation is limited:
|
||||
* <b>`ValueError`</b>: if the shape of 'image' is incompatible with this function.
|
||||
|
||||
|
||||
|
||||
## Other Functions and Classes
|
||||
- - -
|
||||
|
||||
### `tf.image.resize_nearest_neighbor_grad(grads, size, name=None)` {#resize_nearest_neighbor_grad}
|
||||
|
||||
Computes the gradient of nearest neighbor interpolation.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`grads`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`.
|
||||
4-D with shape `[batch, height, width, channels]`.
|
||||
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The
|
||||
original input size.
|
||||
* <b>`name`</b>: A name for the operation (optional).
|
||||
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `grads`.
|
||||
4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients
|
||||
with respect to the input image.
|
||||
|
||||
|
||||
|
@ -215,6 +215,7 @@
|
||||
* [`resize_image_with_crop_or_pad`](../../api_docs/python/image.md#resize_image_with_crop_or_pad)
|
||||
* [`resize_images`](../../api_docs/python/image.md#resize_images)
|
||||
* [`resize_nearest_neighbor`](../../api_docs/python/image.md#resize_nearest_neighbor)
|
||||
* [`resize_nearest_neighbor_grad`](../../api_docs/python/image.md#resize_nearest_neighbor_grad)
|
||||
* [`transpose_image`](../../api_docs/python/image.md#transpose_image)
|
||||
|
||||
* **[Sparse Tensors](../../api_docs/python/sparse_ops.md)**:
|
||||
|
@ -153,6 +153,7 @@ finished with the previous file).
|
||||
|
||||
Op that implements the reader.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.ReaderBase.reset(name=None)` {#ReaderBase.reset}
|
||||
@ -216,6 +217,7 @@ Unimplemented error.
|
||||
Whether the Reader implementation can serialize its state.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.TextLineReader` {#TextLineReader}
|
||||
@ -304,6 +306,7 @@ finished with the previous file).
|
||||
|
||||
Op that implements the reader.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.TextLineReader.reset(name=None)` {#TextLineReader.reset}
|
||||
@ -367,6 +370,7 @@ Unimplemented error.
|
||||
Whether the Reader implementation can serialize its state.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.WholeFileReader` {#WholeFileReader}
|
||||
@ -455,6 +459,7 @@ finished with the previous file).
|
||||
|
||||
Op that implements the reader.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.WholeFileReader.reset(name=None)` {#WholeFileReader.reset}
|
||||
@ -518,6 +523,7 @@ Unimplemented error.
|
||||
Whether the Reader implementation can serialize its state.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.IdentityReader` {#IdentityReader}
|
||||
@ -606,6 +612,7 @@ finished with the previous file).
|
||||
|
||||
Op that implements the reader.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IdentityReader.reset(name=None)` {#IdentityReader.reset}
|
||||
@ -669,6 +676,7 @@ Unimplemented error.
|
||||
Whether the Reader implementation can serialize its state.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.TFRecordReader` {#TFRecordReader}
|
||||
@ -754,6 +762,7 @@ finished with the previous file).
|
||||
|
||||
Op that implements the reader.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.TFRecordReader.reset(name=None)` {#TFRecordReader.reset}
|
||||
@ -817,6 +826,7 @@ Unimplemented error.
|
||||
Whether the Reader implementation can serialize its state.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.FixedLengthRecordReader` {#FixedLengthRecordReader}
|
||||
@ -905,6 +915,7 @@ finished with the previous file).
|
||||
|
||||
Op that implements the reader.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.FixedLengthRecordReader.reset(name=None)` {#FixedLengthRecordReader.reset}
|
||||
@ -969,6 +980,7 @@ Whether the Reader implementation can serialize its state.
|
||||
|
||||
|
||||
|
||||
|
||||
## Converting
|
||||
|
||||
TensorFlow provides several operations that you can use to convert various data
|
||||
@ -1016,16 +1028,17 @@ Reinterpret the bytes of a string as a vector of numbers.
|
||||
All the elements must have the same length.
|
||||
* <b>`out_type`</b>: A `tf.DType` from: `tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.int64`.
|
||||
* <b>`little_endian`</b>: An optional `bool`. Defaults to `True`.
|
||||
Whether the input bytes are in little-endian order.
|
||||
Ignored for out_types that are stored in a single byte like uint8.
|
||||
Whether the input `bytes` are in little-endian order.
|
||||
Ignored for `out_type` values that are stored in a single byte like
|
||||
`uint8`.
|
||||
* <b>`name`</b>: A name for the operation (optional).
|
||||
|
||||
##### Returns:
|
||||
|
||||
A `Tensor` of type `out_type`.
|
||||
A Tensor with one more dimension than the input bytes. The
|
||||
A Tensor with one more dimension than the input `bytes`. The
|
||||
added dimension will have size equal to the length of the elements
|
||||
of bytes divided by the number of bytes to represent out_type.
|
||||
of `bytes` divided by the number of bytes to represent `out_type`.
|
||||
|
||||
|
||||
|
||||
@ -1084,12 +1097,12 @@ serialized `Example`s are provided:
|
||||
|
||||
```
|
||||
serialized = [
|
||||
features:
|
||||
{ feature: [ key: { "ft" value: float_list: { value: [1.0, 2.0] } } ] },
|
||||
features:
|
||||
{ feature: [] },
|
||||
features:
|
||||
{ feature: [ key: { "ft" value: float_list: { value: [3.0] } } ] }
|
||||
features
|
||||
{ feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
|
||||
features
|
||||
{ feature []},
|
||||
features
|
||||
{ feature { key: "ft" value { float_list { value: [3.0] } } }
|
||||
]
|
||||
```
|
||||
|
||||
@ -1105,14 +1118,14 @@ Given two `Example` input protos in `serialized`:
|
||||
|
||||
```
|
||||
[
|
||||
features: {
|
||||
feature: { key: "kw" value: { bytes_list: { value: [ "knit", "big" ] } } }
|
||||
feature: { key: "gps" value: { float_list: { value: [] } } }
|
||||
features {
|
||||
feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } }
|
||||
feature { key: "gps" value { float_list { value: [] } } }
|
||||
},
|
||||
features: {
|
||||
feature: { key: "kw" value: { bytes_list: { value: [ "emmy" ] } } }
|
||||
feature: { key: "dank" value: { int64_list: { value: [ 42 ] } } }
|
||||
feature: { key: "gps" value: { } }
|
||||
features {
|
||||
feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } }
|
||||
feature { key: "dank" value { int64_list { value: [ 42 ] } } }
|
||||
feature { key: "gps" value { } }
|
||||
}
|
||||
]
|
||||
```
|
||||
@ -1148,13 +1161,13 @@ For dense results in two serialized `Example`s:
|
||||
|
||||
```
|
||||
[
|
||||
features: {
|
||||
feature: { key: "age" value: { int64_list: { value: [ 0 ] } } }
|
||||
feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
|
||||
features {
|
||||
feature { key: "age" value { int64_list { value: [ 0 ] } } }
|
||||
feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
|
||||
},
|
||||
features: {
|
||||
feature: { key: "age" value: { int64_list: { value: [] } } }
|
||||
feature: { key: "gender" value: { bytes_list: { value: [ "f" ] } } }
|
||||
features {
|
||||
feature { key: "age" value { int64_list { value: [] } } }
|
||||
feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
|
||||
}
|
||||
]
|
||||
```
|
||||
@ -1202,6 +1215,8 @@ And the expected output is:
|
||||
The keys of the dict must match the dense_keys of the feature.
|
||||
* <b>`dense_shapes`</b>: A list of tuples with the same length as `dense_keys`.
|
||||
The shape of the data for each dense feature referenced by `dense_keys`.
|
||||
Required for any input tensors identified by dense_keys whose shapes are
|
||||
anything other than [] or [1].
|
||||
* <b>`name`</b>: A name for this operation (optional).
|
||||
|
||||
##### Returns:
|
||||
@ -1229,7 +1244,7 @@ same as the shape given in `dense_shape`.
|
||||
|
||||
For `SparseTensor`s, the first (batch) column of the indices matrix is removed
|
||||
(the indices matrix is a column vector), the values vector is unchanged, and
|
||||
the first (batch_size) entry of the shape vector is removed (it is now a
|
||||
the first (`batch_size`) entry of the shape vector is removed (it is now a
|
||||
single element vector).
|
||||
|
||||
See also `parse_example`.
|
||||
@ -1238,15 +1253,15 @@ See also `parse_example`.
|
||||
|
||||
|
||||
* <b>`serialized`</b>: A scalar string Tensor, a single serialized Example.
|
||||
See parse_example documentation for more details.
|
||||
See `parse_example` documentation for more details.
|
||||
* <b>`names`</b>: (Optional) A scalar string Tensor, the associated name.
|
||||
See parse_example documentation for more details.
|
||||
* <b>`sparse_keys`</b>: See parse_example documentation for more details.
|
||||
* <b>`sparse_types`</b>: See parse_example documentation for more details.
|
||||
* <b>`dense_keys`</b>: See parse_example documentation for more details.
|
||||
* <b>`dense_types`</b>: See parse_example documentation for more details.
|
||||
* <b>`dense_defaults`</b>: See parse_example documentation for more details.
|
||||
* <b>`dense_shapes`</b>: See parse_example documentation for more details.
|
||||
See `parse_example` documentation for more details.
|
||||
* <b>`sparse_keys`</b>: See `parse_example` documentation for more details.
|
||||
* <b>`sparse_types`</b>: See `parse_example` documentation for more details.
|
||||
* <b>`dense_keys`</b>: See `parse_example` documentation for more details.
|
||||
* <b>`dense_types`</b>: See `parse_example` documentation for more details.
|
||||
* <b>`dense_defaults`</b>: See `parse_example` documentation for more details.
|
||||
* <b>`dense_shapes`</b>: See `parse_example` documentation for more details.
|
||||
* <b>`name`</b>: A name for this operation (optional).
|
||||
|
||||
##### Returns:
|
||||
@ -1450,12 +1465,38 @@ Constructs a queue object from a queue reference.
|
||||
|
||||
The list of dtypes for each component of a queue element.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.QueueBase.from_list(index, queues)` {#QueueBase.from_list}
|
||||
|
||||
Create a queue using the queue reference from `queues[index]`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`index`</b>: An integer scalar tensor that determines the input that gets
|
||||
selected.
|
||||
* <b>`queues`</b>: A list of `QueueBase` objects.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A `QueueBase` object.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: When `queues` is not a list of `QueueBase` objects,
|
||||
or when the data types of `queues` are not all the same.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.QueueBase.name` {#QueueBase.name}
|
||||
|
||||
The name of the underlying queue.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.QueueBase.queue_ref` {#QueueBase.queue_ref}
|
||||
@ -1463,6 +1504,7 @@ The name of the underlying queue.
|
||||
The underlying queue reference.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.FIFOQueue` {#FIFOQueue}
|
||||
@ -1635,7 +1677,7 @@ Save the list of files matching pattern, so it is only computed once.
|
||||
|
||||
### `tf.train.limit_epochs(tensor, num_epochs=None, name=None)` {#limit_epochs}
|
||||
|
||||
Returns tensor num_epochs times and then raises an OutOfRange error.
|
||||
Returns tensor `num_epochs` times and then raises an `OutOfRange` error.
|
||||
|
||||
##### Args:
|
||||
|
||||
@ -1647,7 +1689,7 @@ Returns tensor num_epochs times and then raises an OutOfRange error.
|
||||
|
||||
##### Returns:
|
||||
|
||||
tensor or OutOfRange.
|
||||
tensor or `OutOfRange`.
|
||||
|
||||
|
||||
- - -
|
||||
@ -1773,6 +1815,12 @@ first dimension. If an input tensor has shape `[*, x, y, z]`, the
|
||||
output will have shape `[batch_size, x, y, z]`. The `capacity` argument
|
||||
controls the how long the prefetching is allowed to grow the queues.
|
||||
|
||||
The returned operation is a dequeue operation and will throw
|
||||
`tf.errors.OutOfRangeError` if the input queue is exhausted. If this
|
||||
operation is feeding another input queue, its queue runner will catch
|
||||
this exception, however, if this operation is used in your main thread
|
||||
you are responsible for catching this yourself.
|
||||
|
||||
*N.B.:* You must ensure that either (i) the `shapes` argument is
|
||||
passed, or (ii) all of the tensors in `tensor_list` must have
|
||||
fully-defined shapes. `ValueError` will be raised if neither of
|
||||
@ -1831,6 +1879,12 @@ same size in the first dimension. The slices of any input tensor
|
||||
The `capacity` argument controls the how long the prefetching is allowed to
|
||||
grow the queues.
|
||||
|
||||
The returned operation is a dequeue operation and will throw
|
||||
`tf.errors.OutOfRangeError` if the input queue is exhausted. If this
|
||||
operation is feeding another input queue, its queue runner will catch
|
||||
this exception, however, if this operation is used in your main thread
|
||||
you are responsible for catching this yourself.
|
||||
|
||||
*N.B.:* You must ensure that either (i) the `shapes` argument is
|
||||
passed, or (ii) all of the tensors in `tensor_list_list` must have
|
||||
fully-defined shapes. `ValueError` will be raised if neither of
|
||||
@ -1886,6 +1940,12 @@ output will have shape `[batch_size, x, y, z]`.
|
||||
The `capacity` argument controls the how long the prefetching is allowed to
|
||||
grow the queues.
|
||||
|
||||
The returned operation is a dequeue operation and will throw
|
||||
`tf.errors.OutOfRangeError` if the input queue is exhausted. If this
|
||||
operation is feeding another input queue, its queue runner will catch
|
||||
this exception, however, if this operation is used in your main thread
|
||||
you are responsible for catching this yourself.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
@ -1961,10 +2021,11 @@ y, z]`, the output will have shape `[batch_size, x, y, z]`.
|
||||
The `capacity` argument controls the how long the prefetching is allowed to
|
||||
grow the queues.
|
||||
|
||||
*N.B.:* You must ensure that either (i) the `shapes` argument is
|
||||
passed, or (ii) all of the tensors in `tensor_list_list` must have
|
||||
fully-defined shapes. `ValueError` will be raised if neither of
|
||||
these conditions holds.
|
||||
The returned operation is a dequeue operation and will throw
|
||||
`tf.errors.OutOfRangeError` if the input queue is exhausted. If this
|
||||
operation is feeding another input queue, its queue runner will catch
|
||||
this exception, however, if this operation is used in your main thread
|
||||
you are responsible for catching this yourself.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
@ -590,7 +590,7 @@ tf.transpose(x) ==> [[1 4]
|
||||
[3 6]]
|
||||
|
||||
# Equivalently
|
||||
tf.transpose(x perm=[0, 1]) ==> [[1 4]
|
||||
tf.transpose(x, perm=[1, 0]) ==> [[1 4]
|
||||
[2 5]
|
||||
[3 6]]
|
||||
|
||||
@ -1355,7 +1355,7 @@ that `segment_ids[j] == i`.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
@ -1389,7 +1389,7 @@ that `segment_ids[j] == i`.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
@ -1423,7 +1423,7 @@ that `segment_ids[j] == i`.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
@ -1456,7 +1456,7 @@ that `segment_ids[j] == i`.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
@ -1491,7 +1491,7 @@ values summed.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
@ -1533,7 +1533,7 @@ If the sum is empty for a given segment ID `i`, `output[i] = 0`.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `num_segments`.
|
||||
|
||||
|
||||
@ -1549,7 +1549,7 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
|
||||
of segments.
|
||||
|
||||
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
|
||||
dimension, selecting a subset of dimension_0, specified by `indices`.
|
||||
dimension, selecting a subset of dimension 0, specified by `indices`.
|
||||
|
||||
For example:
|
||||
|
||||
@ -1587,7 +1587,7 @@ tf.segment_sum(c, tf.constant([0, 0, 1]))
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
@ -1602,7 +1602,7 @@ Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
|
||||
of segments.
|
||||
|
||||
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
|
||||
dimension, selecting a subset of dimension_0, specified by `indices`.
|
||||
dimension, selecting a subset of dimension 0, specified by `indices`.
|
||||
|
||||
##### Args:
|
||||
|
||||
@ -1617,7 +1617,7 @@ dimension, selecting a subset of dimension_0, specified by `indices`.
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `data`.
|
||||
Has same shape as data, except for dimension_0 which
|
||||
Has same shape as data, except for dimension 0 which
|
||||
has size `k`, the number of segments.
|
||||
|
||||
|
||||
|
@ -720,7 +720,7 @@ tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
|
||||
|
||||
|
||||
* <b>`params`</b>: A list of tensors with the same shape and type.
|
||||
* <b>`ids`</b>: A `Tensor` with type `int32` containing the ids to be looked
|
||||
* <b>`ids`</b>: A `Tensor` with type `int32` or `int64` containing the ids to be looked
|
||||
up in `params`.
|
||||
* <b>`name`</b>: A name for the operation (optional).
|
||||
|
||||
@ -850,17 +850,19 @@ with an otherwise unused class.
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`weights`</b>: A `Tensor` of shape [num_classes, dim]. The class embeddings.
|
||||
* <b>`biases`</b>: A `Tensor` of shape [num_classes]. The class biases.
|
||||
* <b>`inputs`</b>: A `Tensor` of shape [batch_size, dim]. The forward
|
||||
* <b>`weights`</b>: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
|
||||
objects whose concatenation along dimension 0 has shape
|
||||
[num_classes, dim]. The (possibly-sharded) class embeddings.
|
||||
* <b>`biases`</b>: A `Tensor` of shape `[num_classes]`. The class biases.
|
||||
* <b>`inputs`</b>: A `Tensor` of shape `[batch_size, dim]`. The forward
|
||||
activations of the input network.
|
||||
* <b>`labels`</b>: A `Tensor` of type `int64` and shape `[batch_size,
|
||||
num_true]`. The target classes.
|
||||
* <b>`num_sampled`</b>: An `int`. The number of classes to randomly sample per batch.
|
||||
* <b>`num_classes`</b>: An `int`. The number of possible classes.
|
||||
* <b>`num_true`</b>: An `int`. The number of target classes per training example.
|
||||
* <b>`sampled_values`</b>: a tuple of `(sampled_candidates, true_expected_count,
|
||||
sampled_expected_count)` returned by a `*_candidate_sampler` function.
|
||||
* <b>`sampled_values`</b>: a tuple of (`sampled_candidates`, `true_expected_count`,
|
||||
`sampled_expected_count`) returned by a `*_candidate_sampler` function.
|
||||
(if None, we default to `log_uniform_candidate_sampler`)
|
||||
* <b>`remove_accidental_hits`</b>: A `bool`. Whether to remove "accidental hits"
|
||||
where a sampled class equals one of the target classes. If set to
|
||||
@ -873,7 +875,7 @@ with an otherwise unused class.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A batch_size 1-D tensor of per-example NCE losses.
|
||||
A `batch_size` 1-D tensor of per-example NCE losses.
|
||||
|
||||
|
||||
- - -
|
||||
@ -899,9 +901,11 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`weights`</b>: A `Tensor` of shape [num_classes, dim]. The class embeddings.
|
||||
* <b>`biases`</b>: A `Tensor` of shape [num_classes]. The class biases.
|
||||
* <b>`inputs`</b>: A `Tensor` of shape [batch_size, dim]. The forward
|
||||
* <b>`weights`</b>: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
|
||||
objects whose concatenation along dimension 0 has shape
|
||||
[num_classes, dim]. The (possibly-sharded) class embeddings.
|
||||
* <b>`biases`</b>: A `Tensor` of shape `[num_classes]`. The class biases.
|
||||
* <b>`inputs`</b>: A `Tensor` of shape `[batch_size, dim]`. The forward
|
||||
activations of the input network.
|
||||
* <b>`labels`</b>: A `Tensor` of type `int64` and shape `[batch_size,
|
||||
num_true]`. The target classes. Note that this format differs from
|
||||
@ -909,8 +913,8 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
|
||||
* <b>`num_sampled`</b>: An `int`. The number of classes to randomly sample per batch.
|
||||
* <b>`num_classes`</b>: An `int`. The number of possible classes.
|
||||
* <b>`num_true`</b>: An `int`. The number of target classes per training example.
|
||||
* <b>`sampled_values`</b>: a tuple of `(sampled_candidates, true_expected_count,
|
||||
sampled_expected_count)` returned by a `*_candidate_sampler` function.
|
||||
* <b>`sampled_values`</b>: a tuple of (`sampled_candidates`, `true_expected_count`,
|
||||
`sampled_expected_count`) returned by a `*_candidate_sampler` function.
|
||||
(if None, we default to `log_uniform_candidate_sampler`)
|
||||
* <b>`remove_accidental_hits`</b>: A `bool`. whether to remove "accidental hits"
|
||||
where a sampled class equals one of the target classes. Default is
|
||||
@ -919,7 +923,7 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A batch_size 1-D tensor of per-example sampled softmax losses.
|
||||
A `batch_size` 1-D tensor of per-example sampled softmax losses.
|
||||
|
||||
|
||||
|
||||
@ -1180,7 +1184,7 @@ compute them approximately.
|
||||
|
||||
### `tf.nn.compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None)` {#compute_accidental_hits}
|
||||
|
||||
Compute the ids of positions in sampled_candidates matching true_classes.
|
||||
Compute the position ids in `sampled_candidates` matching `true_classes`.
|
||||
|
||||
In Candidate Sampling, this operation facilitates virtually removing
|
||||
sampled classes which happen to match target classes. This is done
|
||||
|
@ -91,6 +91,7 @@ The indices of non-zero values in the represented dense tensor.
|
||||
A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
|
||||
number of non-zero values in the tensor, and `ndims` is the rank.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.SparseTensor.values` {#SparseTensor.values}
|
||||
@ -101,18 +102,21 @@ The non-zero values in the represented dense tensor.
|
||||
|
||||
A 1-D Tensor of any data type.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.SparseTensor.dtype` {#SparseTensor.dtype}
|
||||
|
||||
The `DType` of elements in this tensor.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.SparseTensor.shape` {#SparseTensor.shape}
|
||||
|
||||
A 1-D Tensor of int64 representing the shape of the dense tensor.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.SparseTensor.graph` {#SparseTensor.graph}
|
||||
@ -120,6 +124,7 @@ A 1-D Tensor of int64 representing the shape of the dense tensor.
|
||||
The `Graph` that contains the index, value, and shape tensors.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.SparseTensorValue` {#SparseTensorValue}
|
||||
@ -131,12 +136,14 @@ SparseTensorValue(indices, values, shape)
|
||||
|
||||
Alias for field number 0
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.SparseTensorValue.shape` {#SparseTensorValue.shape}
|
||||
|
||||
Alias for field number 2
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.SparseTensorValue.values` {#SparseTensorValue.values}
|
||||
@ -145,6 +152,7 @@ Alias for field number 1
|
||||
|
||||
|
||||
|
||||
|
||||
## Sparse to Dense Conversion
|
||||
|
||||
- - -
|
||||
@ -363,7 +371,7 @@ is during manual manipulation of the indices and values to add entries.
|
||||
|
||||
Reordering does not affect the shape of the `SparseTensor`.
|
||||
|
||||
For example, if sp_input has shape `[4, 5]` and `indices` / `values`:
|
||||
For example, if `sp_input` has shape `[4, 5]` and `indices` / `values`:
|
||||
|
||||
[0, 3]: b
|
||||
[0, 1]: a
|
||||
|
@ -332,12 +332,14 @@ Properties.
|
||||
|
||||
The name of this variable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Variable.dtype` {#Variable.dtype}
|
||||
|
||||
The `DType` of this variable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Variable.get_shape()` {#Variable.get_shape}
|
||||
@ -355,18 +357,21 @@ The `TensorShape` of this variable.
|
||||
|
||||
The device of this variable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Variable.initializer` {#Variable.initializer}
|
||||
|
||||
The initializer operation for this variable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Variable.graph` {#Variable.graph}
|
||||
|
||||
The `Graph` of this variable.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Variable.op` {#Variable.op}
|
||||
@ -375,6 +380,7 @@ The `Operation` of this variable.
|
||||
|
||||
|
||||
|
||||
|
||||
## Variable helper functions
|
||||
|
||||
TensorFlow provides a set of functions to help manage the set of variables
|
||||
@ -587,43 +593,43 @@ saver = tf.train.Saver([v1, v2])
|
||||
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
|
||||
```
|
||||
|
||||
The optional `reshape` argument, if True, allows restoring a variable from
|
||||
The optional `reshape` argument, if `True`, allows restoring a variable from
|
||||
a save file where the variable had a different shape, but the same number
|
||||
of elements and type. This is useful if you have reshaped a variable and
|
||||
want to reload it from an older checkpoint.
|
||||
|
||||
The optional `sharded` argument, if True, instructs the saver to shard
|
||||
The optional `sharded` argument, if `True`, instructs the saver to shard
|
||||
checkpoints per device.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`var_list`</b>: A list of Variables or a dictionary mapping names to
|
||||
Variables. If None, defaults to the list of all variables.
|
||||
* <b>`reshape`</b>: If True, allows restoring parameters from a checkpoint
|
||||
* <b>`var_list`</b>: A list of `Variable` objects or a dictionary mapping names to
|
||||
variables. If `None`, defaults to the list of all variables.
|
||||
* <b>`reshape`</b>: If `True`, allows restoring parameters from a checkpoint
|
||||
where the variables have a different shape.
|
||||
* <b>`sharded`</b>: If True, shard the checkpoints, one per device.
|
||||
* <b>`sharded`</b>: If `True`, shard the checkpoints, one per device.
|
||||
* <b>`max_to_keep`</b>: maximum number of recent checkpoints to keep.
|
||||
Defaults to 10,000 hours.
|
||||
* <b>`keep_checkpoint_every_n_hours`</b>: How often to keep checkpoints.
|
||||
Defaults to 10,000 hours.
|
||||
* <b>`name`</b>: string. Optional name to use as a prefix when adding operations.
|
||||
* <b>`restore_sequentially`</b>: A Bool, which if true, causes restore of different
|
||||
* <b>`restore_sequentially`</b>: A `Bool`, which if true, causes restore of different
|
||||
variables to happen sequentially within each device. This can lower
|
||||
memory usage when restoring very large models.
|
||||
* <b>`saver_def`</b>: Optional SaverDef proto to use instead of running the builder.
|
||||
This is only useful for specialty code that wants to recreate a Saver
|
||||
object for a previously built Graph that had a Saver. The saver_def
|
||||
proto should be the one returned by the as_saver_def() call of the
|
||||
Saver that was created for that Graph.
|
||||
* <b>`builder`</b>: Optional SaverBuilder to use if a saver_def was not provided.
|
||||
Defaults to BaseSaverBuilder().
|
||||
* <b>`saver_def`</b>: Optional `SaverDef` proto to use instead of running the
|
||||
builder. This is only useful for specialty code that wants to recreate
|
||||
a `Saver` object for a previously built `Graph` that had a `Saver`.
|
||||
The `saver_def` proto should be the one returned by the
|
||||
`as_saver_def()` call of the `Saver` that was created for that `Graph`.
|
||||
* <b>`builder`</b>: Optional `SaverBuilder` to use if a `saver_def` was not provided.
|
||||
Defaults to `BaseSaverBuilder()`.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: If `var_list` is invalid.
|
||||
* <b>`ValueError`</b>: If any of the keys or values in `var_list` is not unique.
|
||||
* <b>`ValueError`</b>: If any of the keys or values in `var_list` are not unique.
|
||||
|
||||
|
||||
- - -
|
||||
@ -647,7 +653,7 @@ path can be passed directly to a call to `restore()`.
|
||||
`sharded`, this is the prefix of the sharded checkpoint filename.
|
||||
* <b>`global_step`</b>: If provided the global step number is appended to
|
||||
`save_path` to create the checkpoint filename. The optional argument
|
||||
can be a Tensor, a Tensor name or an integer.
|
||||
can be a `Tensor`, a `Tensor` name or an integer.
|
||||
* <b>`latest_filename`</b>: Optional name for the protocol buffer file that will
|
||||
contains the list of most recent checkpoint filenames. That file,
|
||||
kept in the same directory as the checkpoint files, is automatically
|
||||
@ -663,7 +669,7 @@ path can be passed directly to a call to `restore()`.
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: If `sess` is not a Session.
|
||||
* <b>`TypeError`</b>: If `sess` is not a `Session`.
|
||||
|
||||
|
||||
- - -
|
||||
@ -683,7 +689,7 @@ The `save_path` argument is typically a value previously returned from a
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`sess`</b>: A Session to use to restore the parameters.
|
||||
* <b>`sess`</b>: A `Session` to use to restore the parameters.
|
||||
* <b>`save_path`</b>: Path where parameters were previously saved.
|
||||
|
||||
|
||||
@ -702,21 +708,22 @@ You can pass any of the returned values to `restore()`.
|
||||
|
||||
A list of checkpoint filenames, sorted from oldest to newest.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.Saver.set_last_checkpoints(last_checkpoints)` {#Saver.set_last_checkpoints}
|
||||
|
||||
Sets the list of not-yet-deleted checkpoint filenames.
|
||||
Sets the list of old checkpoint filenames.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`last_checkpoints`</b>: a list of checkpoint filenames.
|
||||
* <b>`last_checkpoints`</b>: A list of checkpoint filenames.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`AssertionError`</b>: if the list of checkpoint filenames has already been set.
|
||||
* <b>`AssertionError`</b>: If the list of checkpoint filenames has already been set.
|
||||
|
||||
|
||||
- - -
|
||||
@ -748,7 +755,7 @@ Finds the filename of latest saved checkpoint file.
|
||||
|
||||
##### Returns:
|
||||
|
||||
The full path to the latest checkpoint or None if no checkpoint was found.
|
||||
The full path to the latest checkpoint or `None` if no checkpoint was found.
|
||||
|
||||
|
||||
|
||||
@ -1319,12 +1326,14 @@ Creates an `IndexedSlices`.
|
||||
|
||||
A `Tensor` containing the values of the slices.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IndexedSlices.indices` {#IndexedSlices.indices}
|
||||
|
||||
A 1-D `Tensor` containing the indices of the slices.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IndexedSlices.dense_shape` {#IndexedSlices.dense_shape}
|
||||
@ -1332,24 +1341,28 @@ A 1-D `Tensor` containing the indices of the slices.
|
||||
A 1-D `Tensor` containing the shape of the corresponding dense tensor.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IndexedSlices.name` {#IndexedSlices.name}
|
||||
|
||||
The name of this `IndexedSlices`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IndexedSlices.dtype` {#IndexedSlices.dtype}
|
||||
|
||||
The `DType` of elements in this tensor.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IndexedSlices.device` {#IndexedSlices.device}
|
||||
|
||||
The name of the device on which `values` will be produced, or `None`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.IndexedSlices.op` {#IndexedSlices.op}
|
||||
@ -1357,3 +1370,4 @@ The name of the device on which `values` will be produced, or `None`.
|
||||
The `Operation` that produces `values` as an output.
|
||||
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ class directly, but instead instantiate one of its subclasses such as
|
||||
|
||||
### Usage
|
||||
|
||||
```
|
||||
```python
|
||||
# Create an optimizer with the desired parameters.
|
||||
opt = GradientDescentOptimizer(learning_rate=0.1)
|
||||
# Add Ops to the graph to minimize a cost by updating a list of variables.
|
||||
@ -37,7 +37,7 @@ opt_op = opt.minimize(cost, <list of variables>)
|
||||
|
||||
In the training program you will just have to run the returned Op.
|
||||
|
||||
```
|
||||
```python
|
||||
# Execute opt_op to do one step of training:
|
||||
opt_op.run()
|
||||
```
|
||||
@ -54,7 +54,7 @@ before applying them you can instead use the optimizer in three steps:
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
```python
|
||||
# Create an optimizer.
|
||||
opt = GradientDescentOptimizer(learning_rate=0.1)
|
||||
|
||||
@ -96,49 +96,49 @@ This must be called by the constructors of subclasses.
|
||||
|
||||
#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` {#Optimizer.minimize}
|
||||
|
||||
Add operations to minimize 'loss' by updating 'var_list'.
|
||||
Add operations to minimize `loss` by updating `var_list`.
|
||||
|
||||
This method simply combines calls compute_gradients() and
|
||||
apply_gradients(). If you want to process the gradient before applying them
|
||||
call compute_gradients() and apply_gradients() explicitly instead of using
|
||||
this function.
|
||||
This method simply combines calls `compute_gradients()` and
|
||||
`apply_gradients()`. If you want to process the gradient before applying
|
||||
them call `compute_gradients()` and `apply_gradients()` explicitly instead
|
||||
of using this function.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`loss`</b>: A Tensor containing the value to minimize.
|
||||
* <b>`global_step`</b>: Optional Variable to increment by one after the
|
||||
* <b>`loss`</b>: A `Tensor` containing the value to minimize.
|
||||
* <b>`global_step`</b>: Optional `Variable` to increment by one after the
|
||||
variables have been updated.
|
||||
* <b>`var_list`</b>: Optional list of variables.Variable to update to minimize
|
||||
'loss'. Defaults to the list of variables collected in the graph
|
||||
under the key GraphKeys.TRAINABLE_VARIABLES.
|
||||
* <b>`var_list`</b>: Optional list of `Variable` objects to update to minimize
|
||||
`loss`. Defaults to the list of variables collected in the graph
|
||||
under the key `GraphKeys.TRAINABLE_VARIABLES`.
|
||||
* <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be
|
||||
GATE_NONE, GATE_OP, or GATE_GRAPH.
|
||||
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
|
||||
* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
|
||||
Valid values are defined in the class `AggregationMethod`.
|
||||
* <b>`name`</b>: Optional name for the returned operation.
|
||||
|
||||
##### Returns:
|
||||
|
||||
An Operation that updates the variables in 'var_list'. If 'global_step'
|
||||
was not None, that operation also increments global_step.
|
||||
An Operation that updates the variables in `var_list`. If `global_step`
|
||||
was not `None`, that operation also increments `global_step`.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`ValueError`</b>: if some of the variables are not variables.Variable objects.
|
||||
* <b>`ValueError`</b>: if some of the variables are not `Variable` objects.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` {#Optimizer.compute_gradients}
|
||||
|
||||
Compute gradients of "loss" for the variables in "var_list".
|
||||
Compute gradients of `loss` for the variables in `var_list`.
|
||||
|
||||
This is the first part of minimize(). It returns a list
|
||||
This is the first part of `minimize()`. It returns a list
|
||||
of (gradient, variable) pairs where "gradient" is the gradient
|
||||
for "variable". Note that "gradient" can be a Tensor, a
|
||||
IndexedSlices, or None if there is no gradient for the
|
||||
for "variable". Note that "gradient" can be a `Tensor`, an
|
||||
`IndexedSlices`, or `None` if there is no gradient for the
|
||||
given variable.
|
||||
|
||||
##### Args:
|
||||
@ -146,10 +146,10 @@ given variable.
|
||||
|
||||
* <b>`loss`</b>: A Tensor containing the value to minimize.
|
||||
* <b>`var_list`</b>: Optional list of variables.Variable to update to minimize
|
||||
"loss". Defaults to the list of variables collected in the graph
|
||||
under the key GraphKey.TRAINABLE_VARIABLES.
|
||||
`loss`. Defaults to the list of variables collected in the graph
|
||||
under the key `GraphKey.TRAINABLE_VARIABLES`.
|
||||
* <b>`gate_gradients`</b>: How to gate the computation of gradients. Can be
|
||||
GATE_NONE, GATE_OP, or GATE_GRAPH.
|
||||
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
|
||||
* <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
|
||||
Valid values are defined in the class `AggregationMethod`.
|
||||
|
||||
@ -160,7 +160,7 @@ given variable.
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: If var_list contains anything else than variables.Variable.
|
||||
* <b>`TypeError`</b>: If `var_list` contains anything else than `Variable` objects.
|
||||
* <b>`ValueError`</b>: If some arguments are invalid.
|
||||
|
||||
|
||||
@ -170,28 +170,28 @@ given variable.
|
||||
|
||||
Apply gradients to variables.
|
||||
|
||||
This is the second part of minimize(). It returns an Operation that
|
||||
This is the second part of `minimize()`. It returns an `Operation` that
|
||||
applies gradients.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`grads_and_vars`</b>: List of (gradient, variable) pairs as returned by
|
||||
compute_gradients().
|
||||
* <b>`global_step`</b>: Optional Variable to increment by one after the
|
||||
`compute_gradients()`.
|
||||
* <b>`global_step`</b>: Optional `Variable` to increment by one after the
|
||||
variables have been updated.
|
||||
* <b>`name`</b>: Optional name for the returned operation. Default to the
|
||||
name passed to the Optimizer constructor.
|
||||
name passed to the `Optimizer` constructor.
|
||||
|
||||
##### Returns:
|
||||
|
||||
An Operation that applies the specified gradients. If 'global_step'
|
||||
was not None, that operation also increments global_step.
|
||||
An `Operation` that applies the specified gradients. If `global_step`
|
||||
was not None, that operation also increments `global_step`.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if grads_and_vars is malformed.
|
||||
* <b>`TypeError`</b>: if `grads_and_vars` is malformed.
|
||||
|
||||
|
||||
|
||||
@ -203,18 +203,18 @@ gradients.
|
||||
|
||||
The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
|
||||
|
||||
<b>GATE_NONE</b>: Compute and apply gradients in parallel. This provides the
|
||||
maximum parallelism in execution, at the cost of some non-reproducibility in
|
||||
the results. For example the two gradients of MatMul depend on the input
|
||||
<b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
|
||||
the maximum parallelism in execution, at the cost of some non-reproducibility
|
||||
in the results. For example the two gradients of `matmul` depend on the input
|
||||
values: With `GATE_NONE` one of the gradients could be applied to one of the
|
||||
inputs _before_ the other gradient is computed resulting in non-reproducible
|
||||
results.
|
||||
|
||||
<b>GATE_OP</b>: For each Op, make sure all gradients are computed before they
|
||||
are used. This prevents race conditions for Ops that generate gradients for
|
||||
multiple inputs where the gradients depend on the inputs.
|
||||
<b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
|
||||
they are used. This prevents race conditions for Ops that generate gradients
|
||||
for multiple inputs where the gradients depend on the inputs.
|
||||
|
||||
<b>GATE_GRAPH</b>: Make sure all gradients for all variables are computed
|
||||
<b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
|
||||
before any one of them is used. This provides the least parallelism but can
|
||||
be useful if you want to process all gradients before applying any of them.
|
||||
|
||||
@ -233,9 +233,9 @@ about the slots, etc.
|
||||
|
||||
#### `tf.train.Optimizer.get_slot_names()` {#Optimizer.get_slot_names}
|
||||
|
||||
Return a list of the names of slots created by the Optimizer.
|
||||
Return a list of the names of slots created by the `Optimizer`.
|
||||
|
||||
See get_slot().
|
||||
See `get_slot()`.
|
||||
|
||||
##### Returns:
|
||||
|
||||
@ -246,23 +246,24 @@ See get_slot().
|
||||
|
||||
#### `tf.train.Optimizer.get_slot(var, name)` {#Optimizer.get_slot}
|
||||
|
||||
Return a slot named "name" created for "var" by the Optimizer.
|
||||
Return a slot named `name` created for `var` by the Optimizer.
|
||||
|
||||
Some Optimizer subclasses use additional variables. For example
|
||||
Momentum and Adagrad use variables to accumulate updates. This method
|
||||
gives access to these Variables if for some reason you need them.
|
||||
Some `Optimizer` subclasses use additional variables. For example
|
||||
`Momentum` and `Adagrad` use variables to accumulate updates. This method
|
||||
gives access to these `Variable` objects if for some reason you need them.
|
||||
|
||||
Use get_slot_names() to get the list of slot names created by the Optimizer.
|
||||
Use `get_slot_names()` to get the list of slot names created by the
|
||||
`Optimizer`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`var`</b>: A variable passed to minimize() or apply_gradients().
|
||||
* <b>`var`</b>: A variable passed to `minimize()` or `apply_gradients()`.
|
||||
* <b>`name`</b>: A string.
|
||||
|
||||
##### Returns:
|
||||
|
||||
The Variable for the slot if it was created, None otherwise.
|
||||
The `Variable` for the slot if it was created, `None` otherwise.
|
||||
|
||||
|
||||
|
||||
@ -315,7 +316,7 @@ Construct a new Adagrad optimizer.
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`ValueError`</b>: If the initial_accumulator_value is invalid.
|
||||
* <b>`ValueError`</b>: If the `initial_accumulator_value` is invalid.
|
||||
|
||||
|
||||
|
||||
@ -505,7 +506,7 @@ for y in `ys`.
|
||||
`grad_ys` is a list of tensors of the same length as `ys` that holds
|
||||
the initial gradients for each y in `ys`. When `grad_ys` is None,
|
||||
we fill in a tensor of '1's of the shape of y for each y in `ys`. A
|
||||
user can provide their own initial 'grad_ys` to compute the
|
||||
user can provide their own initial `grad_ys` to compute the
|
||||
derivatives using a different initial gradient for each y (e.g., if
|
||||
one wanted to weight the gradient differently for each value in
|
||||
each y).
|
||||
@ -631,7 +632,7 @@ greater than `clip_value_max` are set to `clip_value_max`.
|
||||
Clips tensor values to a maximum L2-norm.
|
||||
|
||||
Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
|
||||
normalizes `t` so that its L2-norm is less than or equal to `clip_norm'.
|
||||
normalizes `t` so that its L2-norm is less than or equal to `clip_norm`.
|
||||
Specifically, if the L2-norm is already less than or equal to `clip_norm`,
|
||||
then `t` is not modified. If the L2-norm is greater than `clip_norm`, then
|
||||
this operation returns a tensor of the same type and shape as `t` with its
|
||||
@ -664,7 +665,7 @@ Clips tensor values to a maximum average L2-norm.
|
||||
|
||||
Given a tensor `t`, and a maximum clip value `clip_norm`, this operation
|
||||
normalizes `t` so that its average L2-norm is less than or equal to
|
||||
`clip_norm'. Specifically, if the average L2-norm is already less than or
|
||||
`clip_norm`. Specifically, if the average L2-norm is already less than or
|
||||
equal to `clip_norm`, then `t` is not modified. If the average L2-norm is
|
||||
greater than `clip_norm`, then this operation returns a tensor of the same
|
||||
type and shape as `t` with its values set to:
|
||||
@ -700,18 +701,18 @@ and the global norm (`global_norm`) of all tensors in `t_list`. Optionally,
|
||||
if you've already computed the global norm for `t_list`, you can specify
|
||||
the global norm with `use_norm`.
|
||||
|
||||
To perform the clipping, the values t_list[i] are set to:
|
||||
To perform the clipping, the values `t_list[i]` are set to:
|
||||
|
||||
`t_list[i] * clip_norm / max(global_norm, clip_norm)`
|
||||
t_list[i] * clip_norm / max(global_norm, clip_norm)
|
||||
|
||||
where:
|
||||
|
||||
`global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))`
|
||||
global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))
|
||||
|
||||
If `clip_norm > global_norm` then the entries in `t_list` remain as they are,
|
||||
otherwise they're all shrunk by the global ratio.
|
||||
|
||||
Any of the entries of `t_list` that are of type None are ignored.
|
||||
Any of the entries of `t_list` that are of type `None` are ignored.
|
||||
|
||||
This is the correct way to perform gradient clipping (for example, see
|
||||
R. Pascanu, T. Mikolov, and Y. Bengio, "On the difficulty of training
|
||||
@ -1136,28 +1137,28 @@ Create a new Coordinator.
|
||||
|
||||
Wait for threads to terminate.
|
||||
|
||||
Blocks until all 'threads' have terminated or request_stop() is called.
|
||||
Blocks until all `threads` have terminated or `request_stop()` is called.
|
||||
|
||||
After the threads stop, if an 'exc_info' was passed to request_stop, that
|
||||
After the threads stop, if an `exc_info` was passed to `request_stop`, that
|
||||
exception is re-reaised.
|
||||
|
||||
Grace period handling: When request_stop() is called, threads are given
|
||||
Grace period handling: When `request_stop()` is called, threads are given
|
||||
'stop_grace_period_secs' seconds to terminate. If any of them is still
|
||||
alive after that period expires, a RuntimeError is raised. Note that if
|
||||
an 'exc_info' was passed to request_stop() then it is raised instead of
|
||||
that RuntimeError.
|
||||
alive after that period expires, a `RuntimeError` is raised. Note that if
|
||||
an `exc_info` was passed to `request_stop()` then it is raised instead of
|
||||
that `RuntimeError`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`threads`</b>: List threading.Threads. The started threads to join.
|
||||
* <b>`threads`</b>: List of `threading.Threads`. The started threads to join.
|
||||
* <b>`stop_grace_period_secs`</b>: Number of seconds given to threads to stop after
|
||||
request_stop() has been called.
|
||||
`request_stop()` has been called.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`RuntimeError`</b>: If any thread is still alive after request_stop()
|
||||
* <b>`RuntimeError`</b>: If any thread is still alive after `request_stop()`
|
||||
is called and the grace period expires.
|
||||
|
||||
|
||||
@ -1167,14 +1168,14 @@ that RuntimeError.
|
||||
|
||||
Request that the threads stop.
|
||||
|
||||
After this is called, calls to should_stop() will return True.
|
||||
After this is called, calls to `should_stop()` will return `True`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`ex`</b>: Optional Exception, or Python 'exc_info' tuple as returned by
|
||||
sys.exc_info(). If this is the first call to request_stop() the
|
||||
corresponding exception is recorded and re-raised from join().
|
||||
* <b>`ex`</b>: Optional `Exception`, or Python `exc_info` tuple as returned by
|
||||
`sys.exc_info()`. If this is the first call to `request_stop()` the
|
||||
corresponding exception is recorded and re-raised from `join()`.
|
||||
|
||||
|
||||
- - -
|
||||
@ -1306,6 +1307,7 @@ depending on whether or not a `Coordinator` was passed to
|
||||
was captured. (No exceptions are captured when using a Coordinator.)
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.train.add_queue_runner(qr, collection='queue_runners')` {#add_queue_runner}
|
||||
@ -1768,7 +1770,7 @@ global_step: 10
|
||||
|
||||
Writes a graph proto on disk.
|
||||
|
||||
The graph is written as a binary proto unless as_text is `True`.
|
||||
The graph is written as a binary proto unless `as_text` is `True`.
|
||||
|
||||
```python
|
||||
v = tf.Variable(0, name='my_variable')
|
||||
|
Before ![]() (image error) Size: 264 KiB |
@ -2,45 +2,47 @@
|
||||
|
||||
Let's get you up and running with TensorFlow!
|
||||
|
||||
But before we even get started, let's give you a sneak peek at what TensorFlow
|
||||
code looks like in the Python API, just so you have a sense of where we're
|
||||
But before we even get started, let's peek at what TensorFlow
|
||||
code looks like in the Python API, so you have a sense of where we're
|
||||
headed.
|
||||
|
||||
Here's a little Python program that makes up some data in three dimensions, and
|
||||
then fits a plane to it.
|
||||
Here's a little Python program that makes up some data in two dimensions, and
|
||||
then fits a line to it.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
# Make 100 phony data points in NumPy.
|
||||
x_data = np.float32(np.random.rand(2, 100)) # Random input
|
||||
y_data = np.dot([0.100, 0.200], x_data) + 0.300
|
||||
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
|
||||
x_data = np.random.rand(100).astype("float32")
|
||||
y_data = x_data * 0.1 + 0.3
|
||||
|
||||
# Construct a linear model.
|
||||
# Try to find values for W and b that compute y_data = W * x_data + b
|
||||
# (We know that W should be 0.1 and b 0.3, but Tensorflow will
|
||||
# figure that out for us.)
|
||||
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
|
||||
b = tf.Variable(tf.zeros([1]))
|
||||
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
|
||||
y = tf.matmul(W, x_data) + b
|
||||
y = W * x_data + b
|
||||
|
||||
# Minimize the squared errors.
|
||||
# Minimize the mean squared errors.
|
||||
loss = tf.reduce_mean(tf.square(y - y_data))
|
||||
optimizer = tf.train.GradientDescentOptimizer(0.5)
|
||||
train = optimizer.minimize(loss)
|
||||
|
||||
# For initializing the variables.
|
||||
# Before starting, initialize the variables. We will 'run' this first.
|
||||
init = tf.initialize_all_variables()
|
||||
|
||||
# Launch the graph
|
||||
# Launch the graph.
|
||||
sess = tf.Session()
|
||||
sess.run(init)
|
||||
|
||||
# Fit the plane.
|
||||
for step in xrange(0, 201):
|
||||
# Fit the line.
|
||||
for step in xrange(201):
|
||||
sess.run(train)
|
||||
if step % 20 == 0:
|
||||
print step, sess.run(W), sess.run(b)
|
||||
|
||||
# Learns best fit is W: [[0.100 0.200]], b: [0.300]
|
||||
# Learns best fit is W: [0.1], b: [0.3]
|
||||
```
|
||||
|
||||
The first part of this code builds the data flow graph. TensorFlow does not
|
||||
|
@ -1,133 +1,189 @@
|
||||
# Download and Setup
|
||||
|
||||
You can install TensorFlow using our provided binary packages or from source.
|
||||
You can install TensorFlow either from our provided binary packages or from the
|
||||
github source.
|
||||
|
||||
## Binary Installation
|
||||
## Requirements
|
||||
|
||||
The TensorFlow Python API currently requires Python 2.7: we are
|
||||
[working](https://github.com/tensorflow/tensorflow/issues/1) on adding support
|
||||
for Python 3.
|
||||
The TensorFlow Python API currently requires Python 2.7. We are
|
||||
[adding support for Python 3](https://github.com/tensorflow/tensorflow/issues/1).
|
||||
|
||||
The simplest way to install TensorFlow is using
|
||||
[pip](https://pypi.python.org/pypi/pip) for both Linux and Mac.
|
||||
The GPU version (Linux only) currently requires the Cuda Toolkit 7.0 and CUDNN
|
||||
6.5 V2. Please see [Cuda installation](#install_cuda).
|
||||
|
||||
## Overview
|
||||
|
||||
We support different ways to install TensorFlow:
|
||||
|
||||
* [Pip install](#pip_install): Install TensorFlow on your machine, possibly
|
||||
upgrading previously installed Python packages. May impact existing
|
||||
Python programs on your machine.
|
||||
* [Virtualenv install](#virtualenv_install): Install TensorFlow in its own
|
||||
directory, not impacting any existing Python programs on your machine.
|
||||
* [Docker install](#docker_install): Run TensorFlow in a Docker container
|
||||
isolated from all other programs on your machine.
|
||||
|
||||
If you are familiar with Pip, Virtualenv, or Docker, please feel free to adapt
|
||||
the instructions to your particular needs. The names of the pip and Docker
|
||||
images are listed in the corresponding installation sections.
|
||||
|
||||
If you encounter installation errors, see
|
||||
[common problems](#common_install_problems) for some solutions. To simplify
|
||||
installation, please consider using our virtualenv-based instructions
|
||||
[here](#virtualenv_install).
|
||||
[common problems](#common_install_problems) for some solutions.
|
||||
|
||||
### Ubuntu/Linux 64-bit
|
||||
## Pip Installation {#pip_install}
|
||||
|
||||
[Pip](https://en.wikipedia.org/wiki/Pip_(package_manager)) is a package
|
||||
management system used to install and manage software packages written in
|
||||
Python.
|
||||
|
||||
The packages that will be installed or upgraded during the pip install are listed in the
|
||||
[REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py)
|
||||
|
||||
Install pip if not already installed:
|
||||
|
||||
```bash
|
||||
# For CPU-only version
|
||||
$ pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
# Ubuntu/Linux 64-bit
|
||||
$ sudo apt-get install python-pip python-dev
|
||||
|
||||
# For GPU-enabled version (only install this version if you have the CUDA sdk installed)
|
||||
$ pip install https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
# Mac OS X
|
||||
$ sudo easy_install pip
|
||||
```
|
||||
|
||||
### Mac OS X
|
||||
|
||||
On OS X, we recommend installing [homebrew](http://brew.sh) and `brew install
|
||||
python` before proceeding, or installing TensorFlow within [virtualenv](#virtualenv_install).
|
||||
Install TensorFlow:
|
||||
|
||||
```bash
|
||||
# Only CPU-version is available at the moment.
|
||||
$ pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
|
||||
# Ubuntu/Linux 64-bit, CPU only:
|
||||
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled:
|
||||
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only:
|
||||
$ sudo easy_install --upgrade six
|
||||
$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
|
||||
```
|
||||
|
||||
## Docker-based installation
|
||||
You can now [test your installation](#test_install).
|
||||
|
||||
We also support running TensorFlow via [Docker](http://docker.com/), which lets
|
||||
you avoid worrying about setting up dependencies.
|
||||
## Virtualenv installation {#virtualenv_install}
|
||||
|
||||
First, [install Docker](http://docs.docker.com/engine/installation/). Once
|
||||
Docker is up and running, you can start a container with one command:
|
||||
[Virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) is a tool
|
||||
to keep the dependencies required by different Python projects in separate
|
||||
places. The Virtualenv installation of TensorFlow will not override
|
||||
pre-existing version of the Python packages needed by TensorFlow.
|
||||
|
||||
With [Virtualenv](https://pypi.python.org/pypi/virtualenv) the installation is
|
||||
as follows:
|
||||
|
||||
* Install pip and Virtualenv.
|
||||
* Create a Virtualenv environment.
|
||||
* Activate the Virtualenv environment and install TensorFlow in it.
|
||||
* After the install you will activate the Virtualenv environment each time you
|
||||
want to use TensorFlow.
|
||||
|
||||
Install pip and Virtualenv:
|
||||
|
||||
```bash
|
||||
# Ubuntu/Linux 64-bit
|
||||
$ sudo apt-get install python-pip python-dev python-virtualenv
|
||||
|
||||
# Mac OS X
|
||||
$ sudo easy_install pip
|
||||
$ sudo pip install --upgrade virtualenv
|
||||
```
|
||||
|
||||
Create a Virtualenv environment in the directory `~/tensorflow`:
|
||||
|
||||
```bash
|
||||
$ virtualenv --system-site-packages ~/tensorflow
|
||||
```
|
||||
|
||||
Activate the environment and use pip to install TensorFlow inside it:
|
||||
|
||||
```bash
|
||||
$ source ~/tensorflow/bin/activate # If using bash
|
||||
$ source ~/tensorflow/bin/activate.csh # If using csh
|
||||
(tensorflow)$ # Your prompt should change
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only:
|
||||
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled:
|
||||
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Mac OS X, CPU only:
|
||||
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
|
||||
```
|
||||
|
||||
With the Virtualenv environment activated, you can now
|
||||
[test your installation](#test_install).
|
||||
|
||||
```bash
|
||||
# When you are done using TensorFlow, deactivate the environment.
|
||||
(tensorflow)$ deactivate
|
||||
|
||||
$ # Your prompt should change back
|
||||
```
|
||||
|
||||
To use TensorFlow later you will have to activate the Virtualenv environment again:
|
||||
|
||||
```bash
|
||||
$ source ~/tensorflow/bin/activate # If using bash.
|
||||
$ source ~/tensorflow/bin/activate.csh # If using csh.
|
||||
(tensorflow)$ # Your prompt should change.
|
||||
# Run Python programs that use TensorFlow.
|
||||
...
|
||||
# When you are done using TensorFlow, deactivate the environment.
|
||||
(tensorflow)$ deactivate
|
||||
```
|
||||
|
||||
## Docker installation {#docker_install}
|
||||
|
||||
[Docker](http://docker.com/) is a system to build self contained versions of a
|
||||
Linux operating system running on your machine. When you install and run
|
||||
TensorFlow via Docker it completely isolates the installation from pre-existing
|
||||
packages on your machine.
|
||||
|
||||
We provide 2 Docker images:
|
||||
|
||||
* `b.gcr.io/tensorflow/tensorflow`: TensorFlow CPU binary image.
|
||||
* `b.gcr.io/tensorflow/tensorflow-full`: CPU Binary image plus source code.
|
||||
|
||||
With Docker the installation is as follows:
|
||||
|
||||
* Install Docker on your machine.
|
||||
* Launch a Docker container with the TensorFlow image. The image
|
||||
gets downloaded automatically on first launch.
|
||||
|
||||
See [installing Docker](http://docs.docker.com/engine/installation/) for instructions
|
||||
on installing Docker on your machine.
|
||||
|
||||
Also create a [Docker
|
||||
group](http://docs.docker.com/engine/installation/ubuntulinux/#create-a-docker-group)
|
||||
to allow launching containers without `sudo`.
|
||||
|
||||
After Docker is installed, launch a Docker container with the TensorFlow binary
|
||||
image as follows.
|
||||
|
||||
```bash
|
||||
$ docker run -it b.gcr.io/tensorflow/tensorflow
|
||||
```
|
||||
|
||||
This will start a container with TensorFlow and all its dependencies already
|
||||
installed.
|
||||
Within the Docker container, you can now [test your installation](#test_install).
|
||||
|
||||
### Additional images
|
||||
|
||||
The default Docker image above contains just a minimal set of libraries for
|
||||
getting up and running with TensorFlow. We also have the following container,
|
||||
which you can use in the `docker run` command above:
|
||||
|
||||
* `b.gcr.io/tensorflow/tensorflow-full`: Contains a complete TensorFlow source
|
||||
installation, including all utilities needed to build and run TensorFlow. This
|
||||
makes it easy to experiment directly with the source, without needing to
|
||||
install any of the dependencies described above.
|
||||
|
||||
## VirtualEnv-based installation {#virtualenv_install}
|
||||
|
||||
We recommend using [virtualenv](https://pypi.python.org/pypi/virtualenv) to
|
||||
create an isolated container and install TensorFlow in that container -- it is
|
||||
optional but makes verifying installation issues easier.
|
||||
|
||||
First, install all required tools:
|
||||
You can alternatively launch the TensorFlow source image, for example if you want
|
||||
to experiment directly with the source.
|
||||
|
||||
```bash
|
||||
# On Linux:
|
||||
$ sudo apt-get install python-pip python-dev python-virtualenv
|
||||
|
||||
# On Mac:
|
||||
$ sudo easy_install pip # If pip is not already installed
|
||||
$ sudo pip install --upgrade virtualenv
|
||||
$ docker run -it b.gcr.io/tensorflow/tensorflow-full
|
||||
```
|
||||
|
||||
Next, set up a new virtualenv environment. To set it up in the
|
||||
directory `~/tensorflow`, run:
|
||||
## Test the TensorFlow installation {#test_install}
|
||||
|
||||
```bash
|
||||
$ virtualenv --system-site-packages ~/tensorflow
|
||||
$ cd ~/tensorflow
|
||||
```
|
||||
### (Optional, Linux) Enable GPU Support
|
||||
|
||||
Then activate the virtualenv:
|
||||
|
||||
```bash
|
||||
$ source bin/activate # If using bash
|
||||
$ source bin/activate.csh # If using csh
|
||||
(tensorflow)$ # Your prompt should change
|
||||
```
|
||||
|
||||
Inside the virtualenv, install TensorFlow:
|
||||
|
||||
```bash
|
||||
# For CPU-only linux x86_64 version
|
||||
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# For GPU-enabled linux x86_64 version
|
||||
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# For Mac CPU-only version
|
||||
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
|
||||
```
|
||||
|
||||
Make sure you have downloaded the source code for TensorFlow, and then you can
|
||||
then run an example TensorFlow program like:
|
||||
|
||||
```bash
|
||||
(tensorflow)$ cd tensorflow/models/image/mnist
|
||||
(tensorflow)$ python convolutional.py
|
||||
|
||||
# When you are done using TensorFlow:
|
||||
(tensorflow)$ deactivate # Deactivate the virtualenv
|
||||
|
||||
$ # Your prompt should change back
|
||||
```
|
||||
|
||||
## Try your first TensorFlow program
|
||||
|
||||
### (Optional) Enable GPU Support
|
||||
|
||||
If you installed the GPU-enabled TensorFlow pip binary, you must have the
|
||||
correct versions of the CUDA SDK and CUDNN installed on your
|
||||
system. Please see [the CUDA installation instructions](#install_cuda).
|
||||
If you installed the GPU version of TensorFlow, you must also install the Cuda
|
||||
Toolkit 7.0 and CUDNN 6.5 V2. Please see [Cuda installation](#install_cuda).
|
||||
|
||||
You also need to set the `LD_LIBRARY_PATH` and `CUDA_HOME` environment
|
||||
variables. Consider adding the commands below to your `~/.bash_profile`. These
|
||||
@ -138,13 +194,15 @@ export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64"
|
||||
export CUDA_HOME=/usr/local/cuda
|
||||
```
|
||||
|
||||
### Run TensorFlow
|
||||
### Run TensorFlow from the Command Line
|
||||
|
||||
Open a python terminal:
|
||||
See [common problems](#common_install_problems) if some error happens.
|
||||
|
||||
Open a terminal and type the following:
|
||||
|
||||
```bash
|
||||
$ python
|
||||
|
||||
...
|
||||
>>> import tensorflow as tf
|
||||
>>> hello = tf.constant('Hello, TensorFlow!')
|
||||
>>> sess = tf.Session()
|
||||
@ -152,10 +210,43 @@ $ python
|
||||
Hello, TensorFlow!
|
||||
>>> a = tf.constant(10)
|
||||
>>> b = tf.constant(32)
|
||||
>>> print sess.run(a+b)
|
||||
>>> print sess.run(a + b)
|
||||
42
|
||||
>>>
|
||||
```
|
||||
|
||||
### Run a TensorFlow demo model
|
||||
|
||||
All TensorFlow packages, including the demo models, are installed in the Python library.
|
||||
The exact location of the Python library depends on your system, but is usually one of:
|
||||
|
||||
```bash
|
||||
/usr/local/lib/python2.7/dist-packages/tensorflow
|
||||
/usr/local/lib/python2.7/site-packages/tensorflow
|
||||
```
|
||||
|
||||
You can find out the directory with the following command:
|
||||
|
||||
```bash
|
||||
$ python -c 'import site; print "\n".join(site.getsitepackages())'
|
||||
```
|
||||
|
||||
The simple demo model for classifying handwritten digits from the MNIST dataset
|
||||
is in the sub-directory `models/image/mnist/convolutional.py`. You can run it from the command
|
||||
line as follows:
|
||||
|
||||
```bash
|
||||
# Using 'python -m' to find the program in the python search path:
|
||||
$ python -m tensorflow.models.image.mnist.convolutional
|
||||
Extracting data/train-images-idx3-ubyte.gz
|
||||
Extracting data/train-labels-idx1-ubyte.gz
|
||||
Extracting data/t10k-images-idx3-ubyte.gz
|
||||
Extracting data/t10k-labels-idx1-ubyte.gz
|
||||
...etc...
|
||||
|
||||
# You can alternatively pass the path to the model program file to the python interpreter.
|
||||
$ python /usr/local/lib/python2.7/dist-packages/tensorflow/models/image/mnist/convolutional.py
|
||||
...
|
||||
```
|
||||
|
||||
## Installing from sources {#source}
|
||||
@ -427,39 +518,39 @@ SyntaxError: invalid syntax
|
||||
|
||||
Solution: make sure you are using Python 2.7.
|
||||
|
||||
### On MacOSX
|
||||
### Mac OS X: ImportError: No module named copyreg
|
||||
|
||||
|
||||
If you encounter:
|
||||
On Mac OS X, you may encounter the following when importing tensorflow.
|
||||
|
||||
```python
|
||||
import six.moves.copyreg as copyreg
|
||||
|
||||
>>> import tensorflow as tf
|
||||
...
|
||||
ImportError: No module named copyreg
|
||||
```
|
||||
|
||||
Solution: TensorFlow depends on protobuf, which requires `six-1.10.0`. Apple's
|
||||
default python environment has `six-1.4.1` and may be difficult to upgrade.
|
||||
There are several ways to fix this:
|
||||
Solution: TensorFlow depends on protobuf, which requires the Python package
|
||||
`six-1.10.0`. Apple's default Python installation only provides `six-1.4.1`.
|
||||
|
||||
1. Upgrade the system-wide copy of `six`:
|
||||
You can resolve the issue in one of the following ways:
|
||||
|
||||
```bash
|
||||
sudo easy_install -U six
|
||||
```
|
||||
* Upgrade the Python installation with the current version `six`:
|
||||
|
||||
2. Install a separate copy of python via homebrew:
|
||||
```bash
|
||||
$ sudo easy_install -U six
|
||||
```
|
||||
|
||||
```bash
|
||||
brew install python
|
||||
```
|
||||
* Install TensorFlow with a separate Python library:
|
||||
|
||||
3. Build or use TensorFlow
|
||||
[within `virtualenv`](#virtualenv_install).
|
||||
* Using [Virtualenv](#virtualenv_install).
|
||||
* Using [Docker](#docker_install).
|
||||
|
||||
* Install a separate copy of Python via [Homebrew](http://brew.sh/) or
|
||||
[MacPorts](https://www.macports.org/) and re-install TensorFlow in that
|
||||
copy of Python.
|
||||
|
||||
# Mac OS X: TypeError: `__init__()` got an unexpected keyword argument 'syntax'
|
||||
|
||||
If you encounter:
|
||||
On Mac OS X, you may encounter the following when importing tensorflow.
|
||||
|
||||
```
|
||||
>>> import tensorflow as tf
|
||||
@ -480,5 +571,5 @@ The best current solution is to make sure older versions of protobuf are not
|
||||
installed, such as:
|
||||
|
||||
```bash
|
||||
brew reinstall --devel protobuf
|
||||
$ pip install --upgrade protobuf
|
||||
```
|
||||
|
Before ![]() (image error) Size: 285 KiB |
@ -117,7 +117,7 @@ Python op wrappers are created automatically in
|
||||
`bazel-genfiles/tensorflow/python/ops/gen_user_ops.py` for all ops placed in the
|
||||
[`tensorflow/core/user_ops`][user_ops] directory when you build Tensorflow.
|
||||
|
||||
> Note: The generated function will be given a snake_case name (to comply with
|
||||
> Note: The generated function will be given a snake\_case name (to comply with
|
||||
> [PEP8](https://www.python.org/dev/peps/pep-0008/)). So if your op is named
|
||||
> `ZeroOut` in the C++ files, the python function will be called `zero_out`.
|
||||
|
||||
@ -294,7 +294,7 @@ which can then be used in the `Compute` method:
|
||||
<code class="lang-c++"><pre>
|
||||
void Compute(OpKernelContext\* context) override {
|
||||
// ...
|
||||
<br/> <b>// Check that preserve_index is in range
|
||||
<br/> <b>// Check that preserve\_index is in range
|
||||
OP\_REQUIRES(context, preserve\_index_ < input.dimension(0),
|
||||
errors::InvalidArgument("preserve\_index out of range"));<br/>
|
||||
</b>// Set all the elements of the output tensor to 0
|
||||
@ -314,7 +314,7 @@ which can then be used in the `Compute` method:
|
||||
> <code class="lang-c++"><pre>
|
||||
> REGISTER\_OP("ZeroOut")
|
||||
> <b>.Attr("preserve\_index: int = 0")</b>
|
||||
> .Input("to_zero: int32")
|
||||
> .Input("to\_zero: int32")
|
||||
> .Output("zeroed: int32");
|
||||
> </pre></code>
|
||||
|
||||
@ -449,7 +449,7 @@ in addition to `int32`s, your Op registration might look like:
|
||||
<code class="lang-c++"><pre>
|
||||
REGISTER\_OP("ZeroOut")
|
||||
<b>.Attr("T: {float, int32}")</b>
|
||||
.Input("to_zero: <b>T</b>")
|
||||
.Input("to\_zero: <b>T</b>")
|
||||
.Output("zeroed: <b>T</b>");
|
||||
</pre></code>
|
||||
|
||||
@ -457,7 +457,7 @@ Your Op registration now specifies that the input's type must be `float`, or
|
||||
`int32`, and that its output will be the same type, since both have type `T`.
|
||||
|
||||
> A note on naming:{#naming} Inputs, outputs, and attrs generally should be
|
||||
> given snake_case names. The one exception is attrs that are used as the type
|
||||
> given snake\_case names. The one exception is attrs that are used as the type
|
||||
> of an input or in the type of an input. Those attrs can be inferred when the
|
||||
> op is added to the graph and so don't appear in the op's function. For
|
||||
> example, this last definition of ZeroOut will generate a Python function that
|
||||
@ -560,7 +560,7 @@ REGISTER\_KERNEL\_BUILDER(
|
||||
> <code class="lang-c++"><pre>
|
||||
> REGISTER\_OP("ZeroOut")
|
||||
> <b>.Attr("T: {float, int32} = DT_INT32")</b>
|
||||
> .Input("to_zero: T")
|
||||
> .Input("to\_zero: T")
|
||||
> .Output("zeroed: T")
|
||||
> </pre></code>
|
||||
|
||||
@ -569,7 +569,7 @@ Lets say you wanted to add more types, say `double`:
|
||||
<code class="lang-c++"><pre>
|
||||
REGISTER\_OP("ZeroOut")
|
||||
<b>.Attr("T: {float, <b>double,</b> int32}")</b>
|
||||
.Input("to_zero: <b>T</b>")
|
||||
.Input("to\_zero: <b>T</b>")
|
||||
.Output("zeroed: <b>T</b>");
|
||||
</pre></code>
|
||||
|
||||
@ -888,7 +888,7 @@ There are several ways to preserve backwards-compatibility.
|
||||
type into a list of varying types).
|
||||
|
||||
The full list of safe and unsafe changes can be found in
|
||||
[tensorflow/core/framework/op_compatibility_test.cc](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_compatibility_test.cc).
|
||||
[`tensorflow/core/framework/op_compatibility_test.cc`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/op_compatibility_test.cc).
|
||||
If you cannot make your change to an operation backwards compatible, then create
|
||||
a new operation with a new name with the new semantics.
|
||||
|
||||
|
Before ![]() (image error) Size: 50 KiB |
Before ![]() (image error) Size: 46 KiB |
Before ![]() (image error) Size: 334 B |
Before ![]() (image error) Size: 380 B |
Before ![]() (image error) Size: 5.8 KiB |
Before ![]() (image error) Size: 317 B |
Before ![]() (image error) Size: 3.2 MiB |
Before ![]() (image error) Size: 669 B |
Before ![]() (image error) Size: 29 KiB |
Before ![]() (image error) Size: 29 KiB |
Before ![]() (image error) Size: 488 B |
Before ![]() (image error) Size: 534 B |