TensorFlow: more features, performance improvements, and doc fixes.

Changes:
- Add Split/Concat() methods to TensorUtil (meant for convenience, not
  speed) by Chris.

- Changes to linear algebra ops interface by Rasmus

- Tests for tensorboard by Daniel

- Fix bug in histogram calculation by Cassandra

- Added tool for backwards compatibility of OpDefs.  Tool
  Checks in history of opdefs and their changes, checks for
  backwards-incompatible changes.  All done by @josh11b

- Fix some protobuf example proto docs by Oliver

- Add derivative of MatrixDeterminant by @yaroslavvb

- Add a priority queue queue by @ebrevdo

- Doc and typo fixes by Aurelien and @dave-andersen

- Speed improvements to ConvBackwardFilter by @andydavis

- Improve speed of Alexnet on TitanX by @zheng-xq

- Add some host memory annotations to some GPU kernels by Yuan.

- Add support for doubles in histogram summary by @jmchen-g

Base CL: 108158338
This commit is contained in:
Vijay Vasudevan 2015-11-18 10:47:35 -08:00
parent 9eb88d56ab
commit ab34d55ce7
111 changed files with 11219 additions and 2743 deletions

View File

@ -11,29 +11,29 @@ package tensorflow;
// features {
// feature {
// key: "age"
// float_list {
// value { float_list {
// value: 29.0
// }
// }}
// }
// feature {
// key: "movie"
// bytes_list {
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }
// }}
// }
// feature {
// key: "movie_ratings"
// float_list {
// value { float_list {
// value: 9.0
// value: 9.7
// }
// }}
// }
// feature {
// key: "suggestion"
// bytes_list {
// value { bytes_list {
// value: "Inception"
// }
// }}
// }
// # Note that this feature exists to be used as a label in training.
// # E.g., if training a logistic regression model to predict purchase
@ -41,9 +41,9 @@ package tensorflow;
// # "suggestion_purchased".
// feature {
// key: "suggestion_purchased"
// float_list {
// value { float_list {
// value: 1.0
// }
// }}
// }
// # Similar to "suggestion_purchased" above this feature exists to be used
// # as a label in training.
@ -52,9 +52,9 @@ package tensorflow;
// # "purchase_price".
// feature {
// key: "purchase_price"
// float_list {
// value { float_list {
// value: 9.99
// }
// }}
// }
// }
//

View File

@ -14,41 +14,41 @@
// Example Features for a movie recommendation application:
// feature {
// key: "age"
// float_list {
// value { float_list {
// value: 29.0
// }
// }}
// }
// feature {
// key: "movie"
// bytes_list {
// value { bytes_list {
// value: "The Shawshank Redemption"
// value: "Fight Club"
// }
// }}
// }
// feature {
// key: "movie_ratings"
// float_list {
// value { float_list {
// value: 9.0
// value: 9.7
// }
// }}
// }
// feature {
// key: "suggestion"
// bytes_list {
// value { bytes_list {
// value: "Inception"
// }
// }}
// }
// feature {
// key: "suggestion_purchased"
// int64_list {
// value { int64_list {
// value: 1
// }
// }}
// }
// feature {
// key: "purchase_price"
// float_list {
// value { float_list {
// value: 9.99
// }
// }}
// }
syntax = "proto3";

View File

@ -24,5 +24,107 @@ Tensor DeepCopy(const Tensor& other) {
return tmp;
}
Tensor Concat(const gtl::ArraySlice<Tensor>& tensors) {
CHECK_GT(tensors.size(), 0);
int64 total_dim0_size = 0;
for (const Tensor& tensor : tensors) {
CHECK_GT(tensor.dims(), 0);
total_dim0_size += tensor.dim_size(0);
}
TensorShape shape = tensors[0].shape();
shape.set_dim(0, total_dim0_size);
Tensor result = Tensor(tensors[0].dtype(), shape);
// We use StringPiece as a convenient map over the tensor buffer,
// but we cast the type to get to the underlying buffer to do the
// copy.
StringPiece to_data = result.tensor_data();
if (DataTypeCanUseMemcpy(result.dtype())) {
int64 offset = 0;
for (const Tensor& tensor : tensors) {
StringPiece from_data = tensor.tensor_data();
CHECK_LE(offset + from_data.size(), to_data.size());
memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
from_data.size());
offset += from_data.size();
}
} else {
CHECK_EQ(DT_STRING, result.dtype());
string* to_strings =
reinterpret_cast<string*>(const_cast<char*>(to_data.data()));
int64 offset = 0;
for (const Tensor& tensor : tensors) {
auto from_strings = tensor.flat<string>();
CHECK_LE(offset + tensor.NumElements(), result.NumElements());
for (int i = 0; i < tensor.NumElements(); ++i) {
to_strings[offset + i] = from_strings(i);
}
offset += tensor.NumElements();
}
}
return result;
}
std::vector<Tensor> Split(const Tensor& tensor,
const gtl::ArraySlice<int64>& sizes) {
CHECK_GT(tensor.dims(), 0);
int64 total_size = 0;
for (int64 size : sizes) {
total_size += size;
}
CHECK_EQ(total_size, tensor.dim_size(0));
std::vector<Tensor> result;
StringPiece from_data = tensor.tensor_data();
if (DataTypeCanUseMemcpy(tensor.dtype())) {
int64 offset = 0;
for (int64 size : sizes) {
TensorShape shape = tensor.shape();
shape.set_dim(0, size);
result.emplace_back(tensor.dtype(), shape);
Tensor* split = &result[result.size() - 1];
// We use StringPiece as a convenient map over the tensor buffer,
// but we cast the type to get to the underlying buffer to do the
// copy.
StringPiece to_data = split->tensor_data();
CHECK_LE(offset + to_data.size(), from_data.size());
memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
to_data.size());
offset += to_data.size();
}
} else {
CHECK_EQ(DT_STRING, tensor.dtype());
auto from_strings = tensor.flat<string>();
int64 offset = 0;
for (int64 size : sizes) {
TensorShape shape = tensor.shape();
shape.set_dim(0, size);
result.emplace_back(tensor.dtype(), shape);
Tensor& split = result[result.size() - 1];
string* to_strings = reinterpret_cast<string*>(
const_cast<char*>(split.tensor_data().data()));
CHECK_LE(offset + split.NumElements(), tensor.NumElements());
for (int i = 0; i < split.NumElements(); ++i) {
to_strings[i] = from_strings(offset + i);
}
offset += split.NumElements();
}
}
return result;
}
} // namespace tensor
} // namespace tensorflow

View File

@ -15,6 +15,28 @@ namespace tensor {
// 'other' is not appropriately memory-aligned.
Tensor DeepCopy(const Tensor& other);
// Concatenates 'tensors' into a single tensor, along their 0th dimension.
//
// REQUIRES: All members of 'tensors' must have the same data type parameter.
// REQUIRES: Each member of 'tensors' must have at least one dimension.
// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
// is not appropriately memory-aligned.
Tensor Concat(const gtl::ArraySlice<Tensor>& tensors);
// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
// dimension. The ith output tensor has 0th-dimension size 'sizes[i]'.
//
// REQUIRES: 'tensor' must have at least one dimension.
// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
// REQUIRES: 'tensor' must point to data stored in CPU memory.
// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
// appropriately memory-aligned.
//
// Split() and Concat() are inverse operations.
std::vector<Tensor> Split(const Tensor& tensor,
const gtl::ArraySlice<int64>& sizes);
} // namespace tensor
} // namespace tensorflow

View File

@ -120,5 +120,81 @@ TEST(TensorUtil, DeepCopySlice) {
}
}
TEST(TensorUtil, Concat) {
std::vector<int64> sizes = {1, 4, 5};
std::vector<Tensor> to_concat;
int64 total_size = 0;
int offset = 0;
for (int entry = 0; entry < sizes.size(); ++entry) {
const int64 size = sizes[entry];
Tensor tensor(DT_INT32, TensorShape({size, 2}));
for (int i = offset; i < offset + size; ++i) {
for (int j = 0; j < 2; ++j) {
tensor.matrix<int32>()(i - offset, j) = 2 * i + j;
}
}
to_concat.push_back(tensor);
total_size += size;
offset += size;
}
Tensor concated = tensor::Concat(to_concat);
ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
for (int i = 0; i < total_size; ++i) {
for (int j = 0; j < 2; ++j) {
EXPECT_EQ(2 * i + j, concated.matrix<int32>()(i, j));
}
}
}
TEST(TensorUtil, Split) {
Tensor to_split(DT_INT64, TensorShape({10, 2}));
for (int i = 0; i < 10; ++i) {
for (int j = 0; j < 2; ++j) {
to_split.matrix<int64>()(i, j) = 2 * i + j;
}
}
std::vector<int64> sizes = {1, 4, 5};
std::vector<Tensor> splits = tensor::Split(to_split, sizes);
ASSERT_EQ(sizes.size(), splits.size());
int offset = 0;
for (int entry = 0; entry < splits.size(); ++entry) {
const int64 size = sizes[entry];
const Tensor& split = splits[entry];
ASSERT_EQ(TensorShape({size, 2}), split.shape());
for (int i = offset; i < offset + size; ++i) {
for (int j = 0; j < 2; ++j) {
EXPECT_EQ(2 * i + j, split.matrix<int64>()(i - offset, j));
}
}
offset += size;
}
}
TEST(TensorUtil, ConcatSplitStrings) {
Tensor x(DT_STRING, TensorShape({4, 3}));
for (int i = 0; i < 4 * 3; ++i) {
x.flat<string>()(i) = strings::StrCat("foo_", i);
}
Tensor x_round_tripped = tensor::Concat(tensor::Split(x, {2, 1, 1}));
ASSERT_EQ(x.shape(), x_round_tripped.shape());
for (int i = 0; i < 4 * 3; ++i) {
EXPECT_EQ(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
}
// Ensure that no memory is being shared between 'x' and 'x_round_tripped'.
for (int i = 0; i < 4 * 3; ++i) {
x_round_tripped.flat<string>()(i) = strings::StrCat("bar_", i);
}
for (int i = 0; i < 4 * 3; ++i) {
EXPECT_NE(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
}
}
} // namespace
} // namespace tensorflow

View File

@ -5,7 +5,12 @@
#include <set>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
// Disable clang-format to prevent 'FixedPoint' header from being included
// before 'Tensor' header on which it depends.
// clang-format off
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
// clang-format on
#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"

View File

@ -2,14 +2,49 @@
#define EIGEN_USE_GPU
#include <algorithm>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bias_op.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
// Definition of the GPU implementations declared in bias_op.cc.
namespace functor {
template <typename T>
__global__ void BiasOpCustomKernel(int nthreads, const T* input, const T* bias,
int bias_size, int replicate_count,
T* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int bias_offset = index % bias_size;
output[index] = input[index] + bias[bias_offset];
}
}
template <typename T, int Dims>
struct Bias<GPUDevice, T, Dims> {
typedef GPUDevice Device;
// Add "bias" to "input", broadcasting it on all dimensions but the last one.
void operator()(const Device& d, typename TTypes<T, Dims>::ConstTensor input,
typename TTypes<T>::ConstVec bias,
typename TTypes<T, Dims>::Tensor output) {
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
CudaLaunchConfig config = GetCudaLaunchConfig(output.size(), d);
BiasOpCustomKernel<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(config.virtual_thread_count,
input.data(), bias.data(), bias_size,
rest_size, output.data());
}
};
} // namespace functor
#define DEFINE_GPU_SPECS(T) \
template struct functor::Bias<GPUDevice, T, 2>; \
template struct functor::Bias<GPUDevice, T, 3>; \

View File

@ -16,10 +16,11 @@
namespace tensorflow {
template <class Scalar, bool SupportsBatchOperationT>
class CholeskyOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
class CholeskyOp
: public UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public:
explicit CholeskyOp(OpKernelConstruction* context)
: LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
: UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
TensorShape GetOutputMatrixShape(
const TensorShape& input_matrix_shape) override {
@ -36,9 +37,10 @@ class CholeskyOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
}
}
using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using
typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using typename UnaryLinearAlgebraOp<Scalar,
SupportsBatchOperationT>::ConstMatrixMap;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
MatrixMap* output) override {

View File

@ -184,6 +184,24 @@ struct PadInput {
}
};
template <typename Device, typename T>
struct NHWCToNCHW {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out);
};
template <typename Device, typename T>
struct NCHWToNHWC {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out);
};
template <typename Device, typename T>
struct ReverseTransformFilter {
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out);
};
} // namespace functor
} // namespace tensorflow

View File

@ -15,6 +15,7 @@
#include "tensorflow/core/public/tensor_shape.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/stream.h"
@ -593,19 +594,29 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
contract_dims[0].first = 0;
contract_dims[0].second = 0;
auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
for (int image_id = 0; image_id < batch; image_id += shard_size) {
const int shard_limit = std::min(static_cast<int>(shard_size),
static_cast<int>(batch) - image_id);
for (int shard_id = 0; shard_id < shard_limit; ++shard_id) {
// TODO(andydavis) Parallelize this loop.
// When we compute the gradient with respect to the filters, we need
// to do im2col to allow gemm-type computation.
Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
stride, col_buffer_data + shard_id * size_A);
input_data += input_offset;
}
auto shard = [&input_data, &col_buffer_data, &in_depth, &input_rows,
&input_cols, &filter_rows, &filter_cols, &pad_top,
&pad_left, &pad_bottom, &pad_right, &stride, &input_offset,
&size_A](int64 start, int64 limit) {
for (int shard_id = start; shard_id < limit; ++shard_id) {
auto input_data_shard = input_data + shard_id * input_offset;
auto col_data_shard = col_buffer_data + shard_id * size_A;
// When we compute the gradient with respect to the filters, we need
// to do im2col to allow gemm-type computation.
Im2col<T>(input_data_shard, in_depth, input_rows, input_cols,
filter_rows, filter_cols, pad_top, pad_left, pad_bottom,
pad_right, stride, stride, col_data_shard);
}
};
Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
size_A, shard);
ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
filter_total_size);
@ -615,6 +626,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
// Gradient with respect to filter.
C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
input_data += input_offset * shard_limit;
out_backprop_data += output_offset * shard_limit;
}
}
@ -795,10 +807,9 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
TensorShape({batch, out_depth, output_rows, output_cols}),
&transformed_out_backprop));
functor::TransformDepth<Device, T, int>()(
context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
Eigen::DSizes<int, 4>(0, 3, 1, 2),
To32Bit(transformed_out_backprop.tensor<T, 4>()));
functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
Tensor pre_transformed_in_backprop;
OP_REQUIRES_OK(context,
@ -831,12 +842,10 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
}
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::TransformDepth<Device, T, int>()(
functor::NCHWToNHWC<Device, T>()(
context->eigen_device<Device>(),
To32Bit(toConstTensor(pre_transformed_in_backprop)
.template tensor<T, 4>()),
Eigen::DSizes<int, 4>(0, 2, 3, 1),
To32Bit(in_backprop->tensor<T, 4>()));
toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
in_backprop->tensor<T, 4>());
} else {
// We fill out a padded out_backprop
TensorShape padded_out_shape(
@ -1033,11 +1042,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
DataTypeToEnum<T>::value,
TensorShape({batch, out_depth, output_rows, output_cols}),
&transformed_out_backprop));
functor::TransformDepth<Device, T, int>()(
context->eigen_device<Device>(), To32Bit(out_backprop.tensor<T, 4>()),
Eigen::DSizes<int, 4>(0, 3, 1, 2),
To32Bit(transformed_out_backprop.tensor<T, 4>()));
functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
out_backprop.tensor<T, 4>(),
transformed_out_backprop.tensor<T, 4>());
Tensor transformed_input;
OP_REQUIRES_OK(context,
@ -1045,11 +1052,9 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
DataTypeToEnum<T>::value,
TensorShape({batch, in_depth, input_rows, input_cols}),
&transformed_input));
functor::TransformDepth<Device, T, int>()(
context->eigen_device<Device>(), To32Bit(input.tensor<T, 4>()),
Eigen::DSizes<int, 4>(0, 3, 1, 2),
To32Bit(transformed_input.tensor<T, 4>()));
functor::NHWCToNCHW<Device, T>()(context->eigen_device<Device>(),
input.tensor<T, 4>(),
transformed_input.tensor<T, 4>());
auto out_backprop_ptr =
AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
@ -1075,12 +1080,11 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
}
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
functor::TransformDepth<Device, T, int>()(
functor::ReverseTransformFilter<Device, T>()(
context->eigen_device<Device>(),
To32Bit(toConstTensor(pre_transformed_filter_backprop)
.template tensor<T, 4>()),
Eigen::DSizes<int, 4>(2, 3, 1, 0),
To32Bit(filter_backprop->tensor<T, 4>()));
toConstTensor(pre_transformed_filter_backprop)
.template tensor<T, 4>(),
filter_backprop->tensor<T, 4>());
} else {
// Fall back to the non-cudnn code path

View File

@ -2,25 +2,255 @@
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/conv_2d.h"
#include <algorithm>
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount>
struct Array {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
return data[index];
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
return data[index];
}
int data[IndexCount];
};
// A dimension type with compile-time known size.
template <int IndexCount>
struct Dimension : Array<int, IndexCount> {};
// An index type with compile-time known size.
template <int IndexCount>
struct Index : Array<int, IndexCount> {};
// A helper function that converts a tensor index into a flat array index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
int flat_index = index[0];
for (int i = 1; i < IndexCount; i++) {
flat_index = flat_index * dims[i] + index[i];
}
return flat_index;
}
// A helper function that converts a flat arrary index into a tensor index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
int index, const Dimension<IndexCount>& dims) {
Index<IndexCount> tensor_index;
for (int i = IndexCount - 1; i >= 0; i--) {
tensor_index[i] = index % dims[i];
index /= dims[i];
}
return tensor_index;
}
// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
template <typename T>
__global__ void SwapDimension0And2InTensor3(int nthreads, const T* input,
Dimension<3> input_dims,
T* output) {
Dimension<3> output_dims;
output_dims[0] = input_dims[2];
output_dims[1] = input_dims[1];
output_dims[2] = input_dims[0];
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
input_tensor_index[0] = output_tensor_index[2];
input_tensor_index[1] = output_tensor_index[1];
input_tensor_index[2] = output_tensor_index[0];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
}
}
// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
template <typename T>
__global__ void SwapDimension1And2InTensor3(int nthreads, const T* input,
Dimension<3> input_dims,
T* output) {
Dimension<3> output_dims;
output_dims[0] = input_dims[0];
output_dims[1] = input_dims[2];
output_dims[2] = input_dims[1];
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
input_tensor_index[0] = output_tensor_index[0];
input_tensor_index[1] = output_tensor_index[2];
input_tensor_index[2] = output_tensor_index[1];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
}
}
// A Cuda custom kernel that converst input to output, given proper padding on
// the left and the top. The padded value is zero.
template <typename T>
__global__ void PadInputCustomKernel(int nthreads, const T* input,
Dimension<4> input_dims, T* output,
Dimension<4> output_dims,
int padding_rows_left,
int padding_cols_left) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int output_index = index;
Index<4> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<4> input_tensor_index;
input_tensor_index[0] = output_tensor_index[0];
input_tensor_index[1] = output_tensor_index[1] - padding_rows_left;
input_tensor_index[2] = output_tensor_index[2] - padding_cols_left;
input_tensor_index[3] = output_tensor_index[3];
if (input_tensor_index[1] >= 0 && input_tensor_index[1] < input_dims[1] &&
input_tensor_index[2] >= 0 && input_tensor_index[2] < input_dims[2]) {
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
output[output_index] = input[input_index];
} else {
output[output_index] = 0;
}
}
}
// A GPU helper function that converts TensorFlow filter format to Cudnn filter
// format.
template <typename T>
struct TransformFilter<GPUDevice, T, int> {
typedef GPUDevice Device;
void operator()(const Device& d, typename TTypes<T, 4, int>::ConstTensor in,
typename TTypes<T, 4, int>::Tensor out) {
Dimension<3> combined_dims;
combined_dims[0] = in.dimension(0) * in.dimension(1);
combined_dims[1] = in.dimension(2);
combined_dims[2] = in.dimension(3);
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
SwapDimension0And2InTensor3<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
// Converts Cudnn filter format back to TensorFlow filter format.
template <typename T>
struct ReverseTransformFilter<GPUDevice, T> {
typedef GPUDevice Device;
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out) {
Dimension<3> combined_dims;
combined_dims[0] = in.dimension(0);
combined_dims[1] = in.dimension(1);
combined_dims[2] = in.dimension(2) * in.dimension(3);
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
SwapDimension0And2InTensor3<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
// A GPU helper function that converts input tensor to a larger output tensor,
// given proper padding values. The padded value is zero.
template <typename T>
struct PadInput<GPUDevice, T, int> {
typedef GPUDevice Device;
void operator()(const Device& d, typename TTypes<T, 4, int>::ConstTensor in,
int padding_rows_left, int padding_rows_right,
int padding_cols_left, int padding_cols_right,
typename TTypes<T, 4, int>::Tensor out) {
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
Dimension<4> input_dims;
for (int i = 0; i < 4; i++) {
input_dims[i] = in.dimension(i);
}
Dimension<4> output_dims;
for (int i = 0; i < 4; i++) {
output_dims[i] = out.dimension(i);
}
PadInputCustomKernel<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), input_dims, out.data(),
output_dims, padding_rows_left, padding_cols_left);
}
};
// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
template <typename T>
struct NHWCToNCHW<GPUDevice, T> {
typedef GPUDevice Device;
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out) {
Dimension<3> combined_dims;
combined_dims[0] = in.dimension(0);
combined_dims[1] = in.dimension(1) * in.dimension(2);
combined_dims[2] = in.dimension(3);
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
SwapDimension1And2InTensor3<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
// Format.
template <typename T>
struct NCHWToNHWC<GPUDevice, T> {
typedef GPUDevice Device;
void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor in,
typename TTypes<T, 4>::Tensor out) {
Dimension<3> combined_dims;
combined_dims[0] = in.dimension(0);
combined_dims[1] = in.dimension(1);
combined_dims[2] = in.dimension(2) * in.dimension(3);
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
SwapDimension1And2InTensor3<
T><<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
} // namespace functor
template struct functor::ShuffleAndReverse<GPUDevice, float, 4, int>;
template struct functor::ShuffleAndReverse<GPUDevice, float, 4,
Eigen::DenseIndex>;
template struct functor::TransformFilter<GPUDevice, float, int>;
template struct functor::ReverseTransformFilter<GPUDevice, float>;
template struct functor::PadInput<GPUDevice, float, int>;
template struct functor::TransformDepth<GPUDevice, float, int>;
// TODO(jiayq): currently pooling ops still use DenseIndex, so I am keeping it
// here.
template struct functor::TransformDepth<GPUDevice, float, Eigen::DenseIndex>;
template struct functor::NHWCToNCHW<GPUDevice, float>;
template struct functor::NCHWToNHWC<GPUDevice, float>;
} // namespace tensorflow

View File

@ -13,10 +13,11 @@
namespace tensorflow {
template <class Scalar, bool SupportsBatchOperationT>
class DeterminantOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
class DeterminantOp
: public UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public:
explicit DeterminantOp(OpKernelConstruction* context)
: LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
: UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
~DeterminantOp() override {}
TensorShape GetOutputMatrixShape(
@ -34,9 +35,10 @@ class DeterminantOp : public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
}
}
using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using
typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using typename UnaryLinearAlgebraOp<Scalar,
SupportsBatchOperationT>::ConstMatrixMap;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
MatrixMap* output) override {

View File

@ -9,7 +9,7 @@
namespace tensorflow {
template <typename T>
template <typename T, typename TARGET_T>
class InTopK : public OpKernel {
public:
explicit InTopK(OpKernelConstruction* context) : OpKernel(context) {
@ -29,7 +29,7 @@ class InTopK : public OpKernel {
" must match length of targets ",
targets_in.dim_size(0)));
const auto& predictions = predictions_in.matrix<T>();
const auto& targets = targets_in.vec<int>();
const auto& targets = targets_in.vec<TARGET_T>();
Tensor* t_out = nullptr;
OP_REQUIRES_OK(context,
@ -53,6 +53,13 @@ class InTopK : public OpKernel {
int k_;
};
REGISTER_KERNEL_BUILDER(Name("InTopK").Device(DEVICE_CPU), InTopK<float>);
REGISTER_KERNEL_BUILDER(Name("InTopK")
.Device(DEVICE_CPU)
.TypeConstraint<int32>("T"),
InTopK<float, int32>);
REGISTER_KERNEL_BUILDER(Name("InTopK")
.Device(DEVICE_CPU)
.TypeConstraint<int64>("T"),
InTopK<float, int64>);
} // namespace tensorflow

View File

@ -2,29 +2,23 @@
namespace tensorflow {
void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
void UnaryLinearAlgebraOpBase::Compute(OpKernelContext* context) {
const Tensor& in = context->input(0);
const int input_rank = GetInputMatrixRank();
OP_REQUIRES(
context, input_rank == 2,
errors::InvalidArgument("Only matrix inputs are supported so far."));
const int input_rank = in.dims();
if (SupportsBatchOperation()) {
OP_REQUIRES(context, in.dims() > input_rank,
errors::InvalidArgument("Input tensor must have rank >= %d",
input_rank + 1));
OP_REQUIRES(context, input_rank >= 2,
errors::InvalidArgument("Input tensor must have rank >= 2"));
} else {
OP_REQUIRES(context, in.dims() == input_rank,
errors::InvalidArgument("Input tensor must have rank == %d",
input_rank));
OP_REQUIRES(context, input_rank == 2,
errors::InvalidArgument("Input tensor must have rank == 2"));
}
// If the tensor rank is greater than input_rank, we consider the inner-most
// dimensions as matrices, and loop over all the other outer
// dimensions to compute the results.
// TODO(kalakris): Only matrix inputs are currently supported.
const int row_dimension = in.dims() - 2;
const int col_dimension = in.dims() - 1;
const int row_dimension = input_rank - 2;
const int col_dimension = input_rank - 1;
const int64 num_rows = in.dim_size(row_dimension);
const int64 num_cols = in.dim_size(col_dimension);
const TensorShape input_matrix_shape = TensorShape({num_rows, num_cols});
@ -36,16 +30,19 @@ void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
int num_matrices = 1;
// The output has the shape of all the outer dimensions of the input
// except for the last two, plus the output_matrix_shape (if the output
// is not scalar). This still assumes that each input matrix is
// 2-dimensional, in accordance with the TODO above.
// is not scalar). This assumes that each input matrix is
// 2-dimensional.
TensorShape output_shape;
if (in.dims() == 2) {
if (input_rank == 2) {
output_shape = output_matrix_shape;
} else {
for (int dim = 0; dim <= in.dims() - 3; ++dim) {
// Add the common outer dimensions.
for (int dim = 0; dim < input_rank - 2; ++dim) {
num_matrices *= in.dim_size(dim);
output_shape.AddDim(in.dim_size(dim));
}
// Add the inner dimensions that depend on the operation implemented by the
// derived class.
for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) {
output_shape.AddDim(output_matrix_shape.dim_size(dim));
}
@ -68,7 +65,7 @@ void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
}
template <typename Scalar, bool SupportsBatchOperationT>
void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
void UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
OpKernelContext* context, int64 matrix_index, const Tensor& in,
const TensorShape& input_matrix_shape, Tensor* out,
const TensorShape& output_matrix_shape) {
@ -90,10 +87,11 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
ComputeMatrix(context, input, &output);
}
// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
template class LinearAlgebraOp<float, false>;
template class LinearAlgebraOp<float, true>;
template class LinearAlgebraOp<double, false>;
template class LinearAlgebraOp<double, true>;
// Explicitly instantiate UnaryLinearAlgebraOp for the scalar types we expect to
// use.
template class UnaryLinearAlgebraOp<float, false>;
template class UnaryLinearAlgebraOp<float, true>;
template class UnaryLinearAlgebraOp<double, false>;
template class UnaryLinearAlgebraOp<double, true>;
} // namespace tensorflow

View File

@ -1,6 +1,10 @@
#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
// Classes to support linear algebra functionality, similar to the numpy.linalg
// module. Supports batch computation on several matrices at once, sharding the
// computations across different threads if necessary.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@ -16,21 +20,12 @@
namespace tensorflow {
// A base class to support linear algebra functionality, similar to the
// numpy.linalg module. Supports batch computation on several matrices at once,
// sharding the computations across different threads if necessary.
//
// TODO(kalakris): This needs to be expanded to support binary inputs, and
// multiple outputs.
class LinearAlgebraOpBase : public OpKernel {
// Base class for unary linear algebra operators.
class UnaryLinearAlgebraOpBase : public OpKernel {
public:
explicit LinearAlgebraOpBase(OpKernelConstruction* context)
explicit UnaryLinearAlgebraOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
~LinearAlgebraOpBase() override {}
// Return the expected rank of the input.
// TODO(kalakris): This should be a virtual function to support vector inputs.
int GetInputMatrixRank() { return 2; }
~UnaryLinearAlgebraOpBase() override {}
// Return the output shape of each individual matrix operation. Must be
// rank 0, 1, or 2. Scalar outputs are rank 0.
@ -62,7 +57,8 @@ class LinearAlgebraOpBase : public OpKernel {
// address
// out->flat<Scalar>().data() +
// matrix_index * output_matrix_shape.num_elements().
// The LinearAlgebraOp<Scalar> class below has functionality which performs
// The UnaryLinearAlgebraOp<Scalar> class below has functionality which
// performs
// this mapping and presents an interface based on the Eigen::MatrixBase API.
virtual void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
const Tensor& in,
@ -72,8 +68,6 @@ class LinearAlgebraOpBase : public OpKernel {
void Compute(OpKernelContext* context) override;
};
// A base class for linear algebra ops templated on the scalar type.
//
// This base class encapsulates the functionality of mapping the input and
// output tensors using Eigen::Map, so that the Eigen::MatrixBase API may be
// directly used by derived classes.
@ -81,10 +75,10 @@ class LinearAlgebraOpBase : public OpKernel {
// will allow the Op to process batches of matrices (rank >= 3); if set to
// false the Op will only accept rank 2 inputs.
template <typename Scalar, bool SupportsBatchOperationT>
class LinearAlgebraOp : public LinearAlgebraOpBase {
class UnaryLinearAlgebraOp : public UnaryLinearAlgebraOpBase {
public:
explicit LinearAlgebraOp(OpKernelConstruction* context)
: LinearAlgebraOpBase(context) {}
explicit UnaryLinearAlgebraOp(OpKernelConstruction* context)
: UnaryLinearAlgebraOpBase(context) {}
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
@ -100,18 +94,18 @@ class LinearAlgebraOp : public LinearAlgebraOpBase {
bool SupportsBatchOperation() final { return SupportsBatchOperationT; }
// A concrete implementation of LinearAlgebraOpBase::ComputeMatrix().
// A concrete implementation of UnaryLinearAlgebraOpBase::ComputeMatrix().
void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
const Tensor& in, const TensorShape& input_matrix_shape,
Tensor* out, const TensorShape& output_matrix_shape) final;
};
// Declare that LinearAlgebraOp is explicitly instantiated in
// Declare that UnaryLinearAlgebraOp is explicitly instantiated in
// linalg_ops_common.cc for float and double.
extern template class LinearAlgebraOp<float, false>;
extern template class LinearAlgebraOp<float, true>;
extern template class LinearAlgebraOp<double, false>;
extern template class LinearAlgebraOp<double, true>;
extern template class UnaryLinearAlgebraOp<float, false>;
extern template class UnaryLinearAlgebraOp<float, true>;
extern template class UnaryLinearAlgebraOp<double, false>;
extern template class UnaryLinearAlgebraOp<double, true>;
} // namespace tensorflow

View File

@ -15,10 +15,10 @@ namespace tensorflow {
template <class Scalar, bool SupportsBatchOperationT>
class MatrixInverseOp
: public LinearAlgebraOp<Scalar, SupportsBatchOperationT> {
: public UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public:
explicit MatrixInverseOp(OpKernelConstruction* context)
: LinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
: UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
~MatrixInverseOp() override {}
TensorShape GetOutputMatrixShape(
@ -36,10 +36,11 @@ class MatrixInverseOp
}
}
using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix;
using typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::Matrix;
using
typename LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ConstMatrixMap;
typename UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::MatrixMap;
using typename UnaryLinearAlgebraOp<Scalar,
SupportsBatchOperationT>::ConstMatrixMap;
void ComputeMatrix(OpKernelContext* context, const ConstMatrixMap& input,
MatrixMap* output) override {

View File

@ -156,4 +156,15 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#endif // GOOGLE_CUDA
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Pad")
.Device(DEVICE_GPU)
.TypeConstraint<int32>("T")
.HostMemory("input")
.HostMemory("paddings")
.HostMemory("output"),
PadOp<CPUDevice, int32>);
} // end namespace tensorflow

View File

@ -167,26 +167,25 @@ void DnnPoolingGradOp<T>::Compute(
out_backprop.dim_size(1), out_backprop.dim_size(2)}),
&transformed_output_backprop));
auto nhwc_to_nchw = Eigen::DSizes<Eigen::DenseIndex, 4>(0, 3, 1, 2);
if (tensor_in) {
// For AvgPoolGrad, the original input tensor is not necessary. However,
// cudnn still requires them to run, although they do not affect the
// results.
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
context->eigen_device<Device>(), tensor_in->tensor<T, 4>(),
nhwc_to_nchw, transformed_input.tensor<T, 4>());
functor::NHWCToNCHW<GPUDevice, T>()(context->eigen_device<Device>(),
tensor_in->tensor<T, 4>(),
transformed_input.tensor<T, 4>());
}
if (tensor_out) {
// For AvgPoolGrad, the original output tensor is not necessary. However,
// cudnn still requires them to run, although they do not affect the
// results.
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
context->eigen_device<Device>(), tensor_out->tensor<T, 4>(),
nhwc_to_nchw, transformed_output.tensor<T, 4>());
functor::NHWCToNCHW<GPUDevice, T>()(context->eigen_device<Device>(),
tensor_out->tensor<T, 4>(),
transformed_output.tensor<T, 4>());
}
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
functor::NHWCToNCHW<GPUDevice, T>()(
context->eigen_device<Device>(), out_backprop.tensor<T, 4>(),
nhwc_to_nchw, transformed_output_backprop.tensor<T, 4>());
transformed_output_backprop.tensor<T, 4>());
/// Get ready to call cudnn
perftools::gputools::dnn::PoolingDescriptor pooling_desc;
@ -238,11 +237,10 @@ void DnnPoolingGradOp<T>::Compute(
/// Transform the output data from NCHW back to NHWC
auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
auto nchw_to_nhwc = Eigen::DSizes<Eigen::DenseIndex, 4>(0, 2, 3, 1);
functor::TransformDepth<GPUDevice, T, Eigen::DenseIndex>()(
functor::NCHWToNHWC<GPUDevice, T>()(
context->eigen_device<Device>(),
toConstTensor(transformed_input_backprop).template tensor<T, 4>(),
nchw_to_nhwc, output->tensor<T, 4>());
output->tensor<T, 4>());
}
template class DnnPoolingGradOp<float>;

View File

@ -111,7 +111,8 @@ class EnqueueManyOp : public QueueAccessOpKernel {
for (DataType dt : queue->component_dtypes()) {
expected_inputs.push_back(dt);
}
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
callback);
QueueInterface::Tuple tuple;
OpInputList components;

View File

@ -72,7 +72,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"),
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
// A special GPU kernel for int32.
// A special GPU kernel for int32 and bool.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Rank")
@ -82,6 +82,13 @@ REGISTER_KERNEL_BUILDER(Name("Rank")
.HostMemory("output"),
RankOp);
REGISTER_KERNEL_BUILDER(Name("Rank")
.Device(DEVICE_GPU)
.TypeConstraint<bool>("T")
.HostMemory("input")
.HostMemory("output"),
RankOp);
class SizeOp : public OpKernel {
public:
explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

View File

@ -54,6 +54,7 @@ REGISTER_KERNEL_BUILDER(Name("ScalarSummary")
.TypeConstraint<double>("T"),
SummaryScalarOp<double>);
template <typename T>
class SummaryHistoOp : public OpKernel {
public:
// SummaryHistoOp could be extended to take a list of custom bucket
@ -63,13 +64,13 @@ class SummaryHistoOp : public OpKernel {
void Compute(OpKernelContext* c) override {
const Tensor& tags = c->input(0);
const Tensor& values = c->input(1);
const auto flat = values.flat<float>();
const auto flat = values.flat<T>();
OP_REQUIRES(c, TensorShapeUtils::IsLegacyScalar(tags.shape()),
errors::InvalidArgument("tags must be scalar"));
// Build histogram of values in "values" tensor
histogram::Histogram histo;
for (int64 i = 0; i < flat.size(); i++) {
float v = flat(i);
T v = flat(i);
if (!std::isfinite(v)) {
c->SetStatus(
errors::OutOfRange("Nan in summary histogram for: ", name()));
@ -89,8 +90,14 @@ class SummaryHistoOp : public OpKernel {
}
};
REGISTER_KERNEL_BUILDER(Name("HistogramSummary").Device(DEVICE_CPU),
SummaryHistoOp);
REGISTER_KERNEL_BUILDER(Name("HistogramSummary")
.Device(DEVICE_CPU)
.TypeConstraint<float>("T"),
SummaryHistoOp<float>);
REGISTER_KERNEL_BUILDER(Name("HistogramSummary")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
SummaryHistoOp<double>);
struct HistogramResource : public ResourceBase {
histogram::ThreadSafeHistogram histogram;

View File

@ -122,17 +122,17 @@ TEST_F(SummaryScalarOpTest, Error_WrongDimsValues) {
// --------------------------------------------------------------------------
class SummaryHistoOpTest : public OpsTestBase {
protected:
void MakeOp() {
void MakeOp(DataType dt) {
ASSERT_OK(NodeDefBuilder("myop", "HistogramSummary")
.Input(FakeInput())
.Input(FakeInput())
.Input(FakeInput(dt))
.Finalize(node_def()));
ASSERT_OK(InitOp());
}
};
TEST_F(SummaryHistoOpTest, Simple) {
MakeOp();
TEST_F(SummaryHistoOpTest, SimpleFloat) {
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<string>(TensorShape({}), {"taghisto"});
@ -159,8 +159,36 @@ TEST_F(SummaryHistoOpTest, Simple) {
histo.ToString());
}
TEST_F(SummaryHistoOpTest, SimpleDouble) {
MakeOp(DT_DOUBLE);
// Feed and run
AddInputFromArray<string>(TensorShape({}), {"taghisto"});
AddInputFromArray<double>(TensorShape({3, 2}), {0.1, -0.7, 4.1, 4., 5., 4.});
ASSERT_OK(RunOpKernel());
// Check the output size.
Tensor* out_tensor = GetOutput(0);
ASSERT_EQ(0, out_tensor->dims());
Summary summary;
ParseProtoUnlimited(&summary, out_tensor->scalar<string>()());
ASSERT_EQ(summary.value_size(), 1);
EXPECT_EQ(summary.value(0).tag(), "taghisto");
histogram::Histogram histo;
EXPECT_TRUE(histo.DecodeFromProto(summary.value(0).histo()));
EXPECT_EQ(
"Count: 6 Average: 2.7500 StdDev: 2.20\n"
"Min: -0.7000 Median: 3.9593 Max: 5.0000\n"
"------------------------------------------------------\n"
"[ -0.76, -0.69 ) 1 16.667% 16.667% ###\n"
"[ 0.093, 0.1 ) 1 16.667% 33.333% ###\n"
"[ 3.8, 4.2 ) 3 50.000% 83.333% ##########\n"
"[ 4.6, 5.1 ) 1 16.667% 100.000% ###\n",
histo.ToString());
}
TEST_F(SummaryHistoOpTest, Error_WrongDimsTags) {
MakeOp();
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<string>(TensorShape({2, 1}), {"tag1", "tag2"});
@ -170,7 +198,7 @@ TEST_F(SummaryHistoOpTest, Error_WrongDimsTags) {
}
TEST_F(SummaryHistoOpTest, Error_TooManyTagValues) {
MakeOp();
MakeOp(DT_FLOAT);
// Feed and run
AddInputFromArray<string>(TensorShape({2}), {"tag1", "tag2"});

View File

@ -36,9 +36,10 @@ Status TypedQueue<SubQueue>::Initialize() {
}
if (!component_shapes_.empty() &&
component_dtypes_.size() != component_shapes_.size()) {
return errors::InvalidArgument("Different number of component types (",
component_dtypes_.size(), ") vs. shapes (",
component_shapes_.size(), ").");
return errors::InvalidArgument(
"Different number of component types. ", "Types: ",
DataTypeSliceString(component_dtypes_), ", Shapes: ",
ShapeListString(component_shapes_));
}
mutex_lock lock(mu_);

View File

@ -102,24 +102,43 @@ void Histogram::Add(double value) {
double Histogram::Median() const { return Percentile(50.0); }
// Linearly map the variable x from [x0, x1] unto [y0, y1]
double Histogram::Remap(double x, double x0, double x1, double y0,
double y1) const {
return y0 + (x - x0) / (x1 - x0) * (y1 - y0);
}
// Pick tight left-hand-side and right-hand-side bounds and then
// interpolate a histogram value at percentile p
double Histogram::Percentile(double p) const {
if (num_ == 0.0) return 0.0;
double threshold = num_ * (p / 100.0);
double sum = 0;
for (size_t b = 0; b < buckets_.size(); b++) {
sum += buckets_[b];
if (sum >= threshold) {
// Scale linearly within this bucket
double left_point = (b == 0) ? min_ : bucket_limits_[b - 1];
double right_point = bucket_limits_[b];
double left_sum = sum - buckets_[b];
double right_sum = sum;
double pos = (threshold - left_sum) / (right_sum - left_sum);
double r = left_point + (right_point - left_point) * pos;
if (r < min_) r = min_;
if (r > max_) r = max_;
return r;
double cumsum_prev = 0;
for (size_t i = 0; i < buckets_.size(); i++) {
double cumsum = cumsum_prev + buckets_[i];
// Find the first bucket whose cumsum >= threshold
if (cumsum >= threshold) {
// Prevent divide by 0 in remap which happens if cumsum == cumsum_prev
// This should only get hit when p == 0, cumsum == 0, and cumsum_prev == 0
if (cumsum == cumsum_prev) {
continue;
}
// Calculate the lower bound of interpolation
double lhs = (i == 0 || cumsum_prev == 0) ? min_ : bucket_limits_[i - 1];
lhs = std::max(lhs, min_);
// Calculate the upper bound of interpolation
double rhs = bucket_limits_[i];
rhs = std::min(rhs, max_);
double weight = Remap(threshold, cumsum_prev, cumsum, lhs, rhs);
return weight;
}
cumsum_prev = cumsum;
}
return max_;
}

View File

@ -77,6 +77,8 @@ class Histogram {
gtl::ArraySlice<double> bucket_limits_;
std::vector<double> buckets_;
double Remap(double x, double x0, double x1, double y0, double y1) const;
TF_DISALLOW_COPY_AND_ASSIGN(Histogram);
};

View File

@ -51,15 +51,41 @@ TEST(Histogram, CustomBuckets) {
Validate(h);
}
TEST(Histogram, Percentile) {
TEST(Histogram, Median) {
Histogram h({0, 10, 100, DBL_MAX});
h.Add(-2);
h.Add(-2);
h.Add(0);
double median = h.Percentile(50.0);
double median = h.Median();
EXPECT_EQ(median, -0.5);
}
TEST(Histogram, Percentile) {
// 10%, 30%, 40%, 20%
Histogram h({1, 2, 3, 4});
// 10% first bucket
h.Add(-1.0);
// 30% second bucket
h.Add(1.5);
h.Add(1.5);
h.Add(1.5);
// 40% third bucket
h.Add(2.5);
h.Add(2.5);
h.Add(2.5);
h.Add(2.5);
// 20% fourth bucket
h.Add(3.5);
h.Add(3.9);
EXPECT_EQ(h.Percentile(0), -1.0); // -1.0 = histo.min_
EXPECT_EQ(h.Percentile(25), 1.5); // 1.5 = remap(25, 10, 40, 1, 2)
EXPECT_EQ(h.Percentile(50), 2.25); // 2.25 = remap(50, 40, 80, 2, 3)
EXPECT_EQ(h.Percentile(75), 2.875); // 2.875 = remap(75, 40, 80, 2, 3)
EXPECT_EQ(h.Percentile(90), 3.45); // 3.45 = remap(90, 80, 100, 3, 3.9)
EXPECT_EQ(h.Percentile(100), 3.9); // 3.9 = histo.max_
}
TEST(Histogram, Basic) {
Histogram h;
for (int i = 0; i < 100; i++) {

View File

@ -62,9 +62,11 @@ Batch normalization.
t: A 4D input Tensor.
m: A 1D mean Tensor with size matching the last dimension of t.
This is the first output from MovingMoments.
This is the first output from tf.nn.moments,
or a saved moving average thereof.
v: A 1D variance Tensor with size matching the last dimension of t.
This is the second output from MovingMoments.
This is the second output from tf.nn.moments,
or a saved moving average thereof.
beta: A 1D beta Tensor with size matching the last dimension of t.
An offset to be added to the normalized tensor.
gamma: A 1D gamma Tensor with size matching the last dimension of t.
@ -94,9 +96,11 @@ Gradients for batch normalization.
t: A 4D input Tensor.
m: A 1D mean Tensor with size matching the last dimension of t.
This is the first output from MovingMoments.
This is the first output from tf.nn.moments,
or a saved moving average thereof.
v: A 1D variance Tensor with size matching the last dimension of t.
This is the second output from MovingMoments.
This is the second output from tf.nn.moments,
or a saved moving average thereof.
gamma: A 1D gamma Tensor with size matching the last dimension of t.
If "scale_after_normalization" is true, this Tensor will be multiplied
with the normalized Tensor.
@ -488,10 +492,11 @@ backprop: backpropagated gradients (batch_size x num_classes matrix).
// --------------------------------------------------------------------------
REGISTER_OP("InTopK")
.Attr("k: int")
.Input("predictions: float")
.Input("targets: int32")
.Input("targets: T")
.Output("precision: bool")
.Attr("k: int")
.Attr("T: {int32, int64} = DT_INT32")
.Doc(R"doc(
Says whether the targets are in the top K predictions.

File diff suppressed because it is too large Load Diff

View File

@ -23,8 +23,9 @@ summary: Scalar. Serialized `Summary` protocol buffer.
REGISTER_OP("HistogramSummary")
.Input("tag: string")
.Input("values: float")
.Input("values: T")
.Output("summary: string")
.Attr("T: {float, double} = DT_FLOAT")
.Doc(R"doc(
Outputs a `Summary` protocol buffer with a histogram.

View File

@ -4,8 +4,8 @@
// Import whatever namespace protobuf comes from into the
// ::tensorflow::protobuf namespace.
//
// TensorFlow code should the ::tensorflow::protobuf namespace to refer
// to all protobuf APIs.
// TensorFlow code should use the ::tensorflow::protobuf namespace to
// refer to all protobuf APIs.
#include "tensorflow/core/platform/port.h"
#if defined(PLATFORM_GOOGLE)

View File

@ -0,0 +1,24 @@
#ifndef THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
#define THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 0
#define TF_MINOR_VERSION 5
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
#define TF_VERSION_SUFFIX ""
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
// e.g. "0.5.0" or "0.6.0-alpha".
#define TF_VERSION_STRING \
(TF_STR(TF_MAJOR_VERSION) "." TF_STR(TF_MINOR_VERSION) "." TF_STR( \
TF_PATCH_VERSION) TF_VERSION_SUFFIX)
// TODO(josh11b): Public API functions for exporting the above.
#endif // THIRD_PARTY_TENSORFLOW_CORE_PUBLIC_VERSION_H_

View File

@ -0,0 +1,52 @@
#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_
#if GOOGLE_CUDA
#include <algorithm>
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
struct CudaLaunchConfig {
// Logical number of thread that works on the elements. If each logic thread
// works on exactly a single element, this is the same as the working element
// count.
int virtual_thread_count = -1;
// Number of threads per block.
int thread_per_block = -1;
// Number of blocks for Cuda kernel launch.
int block_count = -1;
};
// Calculate the Cuda launch config we should use for a kernel launch.
// This is assuming the kernel is quite simple and will largely be
// memory-limited.
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
const GPUDevice& d) {
const int virtual_thread_count = work_element_count;
const int physical_thread_count = std::min(
d.getNumCudaMultiProcessors() * d.maxCudaThreadsPerMultiProcessor(),
virtual_thread_count);
const int thread_per_block = std::min(1024, d.maxCudaThreadsPerBlock());
const int block_count = std::min(
(physical_thread_count + thread_per_block - 1) / thread_per_block,
d.getNumCudaMultiProcessors());
CudaLaunchConfig config;
config.virtual_thread_count = virtual_thread_count;
config.thread_per_block = thread_per_block;
config.block_count = block_count;
return config;
}
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CUDA_KERNEL_HELPER_H_

View File

@ -1,4 +1,4 @@
# Class `tensorflow::Env` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--env-"></a>
# Class `tensorflow::Env`
An interface used by the tensorflow implementation to access operating system functionality like the filesystem etc.
@ -6,7 +6,7 @@ Callers may wish to provide a custom Env object to get fine grain control.
All Env implementations are safe for concurrent access from multiple threads without any external synchronization.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::Env::Env()`](#tensorflow_Env_Env)
* [`virtual tensorflow::Env::~Env()`](#virtual_tensorflow_Env_Env)
@ -39,21 +39,21 @@ All Env implementations are safe for concurrent access from multiple threads wit
* [`static Env* tensorflow::Env::Default()`](#static_Env_tensorflow_Env_Default)
* Returns a default environment suitable for the current operating system.
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::Env::Env()` <a class="md-anchor" id="tensorflow_Env_Env"></a>
#### `tensorflow::Env::Env()` {#tensorflow_Env_Env}
#### `virtual tensorflow::Env::~Env()` <a class="md-anchor" id="virtual_tensorflow_Env_Env"></a>
#### `virtual tensorflow::Env::~Env()` {#virtual_tensorflow_Env_Env}
#### `virtual Status tensorflow::Env::NewRandomAccessFile(const string &fname, RandomAccessFile **result)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_NewRandomAccessFile"></a>
#### `virtual Status tensorflow::Env::NewRandomAccessFile(const string &fname, RandomAccessFile **result)=0` {#virtual_Status_tensorflow_Env_NewRandomAccessFile}
Creates a brand new random access read-only file with the specified name.
@ -61,7 +61,7 @@ On success, stores a pointer to the new file in *result and returns OK. On failu
The returned file may be concurrently accessed by multiple threads.
#### `virtual Status tensorflow::Env::NewWritableFile(const string &fname, WritableFile **result)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_NewWritableFile"></a>
#### `virtual Status tensorflow::Env::NewWritableFile(const string &fname, WritableFile **result)=0` {#virtual_Status_tensorflow_Env_NewWritableFile}
Creates an object that writes to a new file with the specified name.
@ -69,7 +69,7 @@ Deletes any existing file with the same name and creates a new file. On success,
The returned file will only be accessed by one thread at a time.
#### `virtual Status tensorflow::Env::NewAppendableFile(const string &fname, WritableFile **result)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_NewAppendableFile"></a>
#### `virtual Status tensorflow::Env::NewAppendableFile(const string &fname, WritableFile **result)=0` {#virtual_Status_tensorflow_Env_NewAppendableFile}
Creates an object that either appends to an existing file, or writes to a new file (if the file does not exist to begin with).
@ -77,67 +77,67 @@ On success, stores a pointer to the new file in *result and returns OK. On failu
The returned file will only be accessed by one thread at a time.
#### `virtual bool tensorflow::Env::FileExists(const string &fname)=0` <a class="md-anchor" id="virtual_bool_tensorflow_Env_FileExists"></a>
#### `virtual bool tensorflow::Env::FileExists(const string &fname)=0` {#virtual_bool_tensorflow_Env_FileExists}
Returns true iff the named file exists.
#### `virtual Status tensorflow::Env::GetChildren(const string &dir, std::vector< string > *result)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_GetChildren"></a>
#### `virtual Status tensorflow::Env::GetChildren(const string &dir, std::vector< string > *result)=0` {#virtual_Status_tensorflow_Env_GetChildren}
Stores in *result the names of the children of the specified directory. The names are relative to "dir".
Original contents of *results are dropped.
#### `virtual Status tensorflow::Env::DeleteFile(const string &fname)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_DeleteFile"></a>
#### `virtual Status tensorflow::Env::DeleteFile(const string &fname)=0` {#virtual_Status_tensorflow_Env_DeleteFile}
Deletes the named file.
#### `virtual Status tensorflow::Env::CreateDir(const string &dirname)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_CreateDir"></a>
#### `virtual Status tensorflow::Env::CreateDir(const string &dirname)=0` {#virtual_Status_tensorflow_Env_CreateDir}
Creates the specified directory.
#### `virtual Status tensorflow::Env::DeleteDir(const string &dirname)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_DeleteDir"></a>
#### `virtual Status tensorflow::Env::DeleteDir(const string &dirname)=0` {#virtual_Status_tensorflow_Env_DeleteDir}
Deletes the specified directory.
#### `virtual Status tensorflow::Env::GetFileSize(const string &fname, uint64 *file_size)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_GetFileSize"></a>
#### `virtual Status tensorflow::Env::GetFileSize(const string &fname, uint64 *file_size)=0` {#virtual_Status_tensorflow_Env_GetFileSize}
Stores the size of fname in *file_size.
#### `virtual Status tensorflow::Env::RenameFile(const string &src, const string &target)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Env_RenameFile"></a>
#### `virtual Status tensorflow::Env::RenameFile(const string &src, const string &target)=0` {#virtual_Status_tensorflow_Env_RenameFile}
Renames file src to target. If target already exists, it will be replaced.
#### `virtual uint64 tensorflow::Env::NowMicros()=0` <a class="md-anchor" id="virtual_uint64_tensorflow_Env_NowMicros"></a>
#### `virtual uint64 tensorflow::Env::NowMicros()=0` {#virtual_uint64_tensorflow_Env_NowMicros}
Returns the number of micro-seconds since some fixed point in time. Only useful for computing deltas of time.
#### `virtual void tensorflow::Env::SleepForMicroseconds(int micros)=0` <a class="md-anchor" id="virtual_void_tensorflow_Env_SleepForMicroseconds"></a>
#### `virtual void tensorflow::Env::SleepForMicroseconds(int micros)=0` {#virtual_void_tensorflow_Env_SleepForMicroseconds}
Sleeps/delays the thread for the prescribed number of micro-seconds.
#### `virtual Thread* tensorflow::Env::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) TF_MUST_USE_RESULT=0` <a class="md-anchor" id="virtual_Thread_tensorflow_Env_StartThread"></a>
#### `virtual Thread* tensorflow::Env::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) TF_MUST_USE_RESULT=0` {#virtual_Thread_tensorflow_Env_StartThread}
Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name".
Caller takes ownership of the result and must delete it eventually (the deletion will block until fn() stops running).
#### `static Env* tensorflow::Env::Default()` <a class="md-anchor" id="static_Env_tensorflow_Env_Default"></a>
#### `static Env* tensorflow::Env::Default()` {#static_Env_tensorflow_Env_Default}
Returns a default environment suitable for the current operating system.

View File

@ -1,10 +1,10 @@
# Class `tensorflow::EnvWrapper` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--envwrapper-"></a>
# Class `tensorflow::EnvWrapper`
An implementation of Env that forwards all calls to another Env .
May be useful to clients who wish to override just part of the functionality of another Env .
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::EnvWrapper::EnvWrapper(Env *t)`](#tensorflow_EnvWrapper_EnvWrapper)
* Initializes an EnvWrapper that delegates all calls to *t.
@ -38,27 +38,27 @@ May be useful to clients who wish to override just part of the functionality of
* [`Thread* tensorflow::EnvWrapper::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) override`](#Thread_tensorflow_EnvWrapper_StartThread)
* Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name".
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::EnvWrapper::EnvWrapper(Env *t)` <a class="md-anchor" id="tensorflow_EnvWrapper_EnvWrapper"></a>
#### `tensorflow::EnvWrapper::EnvWrapper(Env *t)` {#tensorflow_EnvWrapper_EnvWrapper}
Initializes an EnvWrapper that delegates all calls to *t.
#### `virtual tensorflow::EnvWrapper::~EnvWrapper()` <a class="md-anchor" id="virtual_tensorflow_EnvWrapper_EnvWrapper"></a>
#### `virtual tensorflow::EnvWrapper::~EnvWrapper()` {#virtual_tensorflow_EnvWrapper_EnvWrapper}
#### `Env* tensorflow::EnvWrapper::target() const` <a class="md-anchor" id="Env_tensorflow_EnvWrapper_target"></a>
#### `Env* tensorflow::EnvWrapper::target() const` {#Env_tensorflow_EnvWrapper_target}
Returns the target to which this Env forwards all calls.
#### `Status tensorflow::EnvWrapper::NewRandomAccessFile(const string &f, RandomAccessFile **r) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_NewRandomAccessFile"></a>
#### `Status tensorflow::EnvWrapper::NewRandomAccessFile(const string &f, RandomAccessFile **r) override` {#Status_tensorflow_EnvWrapper_NewRandomAccessFile}
Creates a brand new random access read-only file with the specified name.
@ -66,7 +66,7 @@ On success, stores a pointer to the new file in *result and returns OK. On failu
The returned file may be concurrently accessed by multiple threads.
#### `Status tensorflow::EnvWrapper::NewWritableFile(const string &f, WritableFile **r) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_NewWritableFile"></a>
#### `Status tensorflow::EnvWrapper::NewWritableFile(const string &f, WritableFile **r) override` {#Status_tensorflow_EnvWrapper_NewWritableFile}
Creates an object that writes to a new file with the specified name.
@ -74,7 +74,7 @@ Deletes any existing file with the same name and creates a new file. On success,
The returned file will only be accessed by one thread at a time.
#### `Status tensorflow::EnvWrapper::NewAppendableFile(const string &f, WritableFile **r) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_NewAppendableFile"></a>
#### `Status tensorflow::EnvWrapper::NewAppendableFile(const string &f, WritableFile **r) override` {#Status_tensorflow_EnvWrapper_NewAppendableFile}
Creates an object that either appends to an existing file, or writes to a new file (if the file does not exist to begin with).
@ -82,61 +82,61 @@ On success, stores a pointer to the new file in *result and returns OK. On failu
The returned file will only be accessed by one thread at a time.
#### `bool tensorflow::EnvWrapper::FileExists(const string &f) override` <a class="md-anchor" id="bool_tensorflow_EnvWrapper_FileExists"></a>
#### `bool tensorflow::EnvWrapper::FileExists(const string &f) override` {#bool_tensorflow_EnvWrapper_FileExists}
Returns true iff the named file exists.
#### `Status tensorflow::EnvWrapper::GetChildren(const string &dir, std::vector< string > *r) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_GetChildren"></a>
#### `Status tensorflow::EnvWrapper::GetChildren(const string &dir, std::vector< string > *r) override` {#Status_tensorflow_EnvWrapper_GetChildren}
Stores in *result the names of the children of the specified directory. The names are relative to "dir".
Original contents of *results are dropped.
#### `Status tensorflow::EnvWrapper::DeleteFile(const string &f) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_DeleteFile"></a>
#### `Status tensorflow::EnvWrapper::DeleteFile(const string &f) override` {#Status_tensorflow_EnvWrapper_DeleteFile}
Deletes the named file.
#### `Status tensorflow::EnvWrapper::CreateDir(const string &d) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_CreateDir"></a>
#### `Status tensorflow::EnvWrapper::CreateDir(const string &d) override` {#Status_tensorflow_EnvWrapper_CreateDir}
Creates the specified directory.
#### `Status tensorflow::EnvWrapper::DeleteDir(const string &d) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_DeleteDir"></a>
#### `Status tensorflow::EnvWrapper::DeleteDir(const string &d) override` {#Status_tensorflow_EnvWrapper_DeleteDir}
Deletes the specified directory.
#### `Status tensorflow::EnvWrapper::GetFileSize(const string &f, uint64 *s) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_GetFileSize"></a>
#### `Status tensorflow::EnvWrapper::GetFileSize(const string &f, uint64 *s) override` {#Status_tensorflow_EnvWrapper_GetFileSize}
Stores the size of fname in *file_size.
#### `Status tensorflow::EnvWrapper::RenameFile(const string &s, const string &t) override` <a class="md-anchor" id="Status_tensorflow_EnvWrapper_RenameFile"></a>
#### `Status tensorflow::EnvWrapper::RenameFile(const string &s, const string &t) override` {#Status_tensorflow_EnvWrapper_RenameFile}
Renames file src to target. If target already exists, it will be replaced.
#### `uint64 tensorflow::EnvWrapper::NowMicros() override` <a class="md-anchor" id="uint64_tensorflow_EnvWrapper_NowMicros"></a>
#### `uint64 tensorflow::EnvWrapper::NowMicros() override` {#uint64_tensorflow_EnvWrapper_NowMicros}
Returns the number of micro-seconds since some fixed point in time. Only useful for computing deltas of time.
#### `void tensorflow::EnvWrapper::SleepForMicroseconds(int micros) override` <a class="md-anchor" id="void_tensorflow_EnvWrapper_SleepForMicroseconds"></a>
#### `void tensorflow::EnvWrapper::SleepForMicroseconds(int micros) override` {#void_tensorflow_EnvWrapper_SleepForMicroseconds}
Sleeps/delays the thread for the prescribed number of micro-seconds.
#### `Thread* tensorflow::EnvWrapper::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) override` <a class="md-anchor" id="Thread_tensorflow_EnvWrapper_StartThread"></a>
#### `Thread* tensorflow::EnvWrapper::StartThread(const ThreadOptions &thread_options, const string &name, std::function< void()> fn) override` {#Thread_tensorflow_EnvWrapper_StartThread}
Returns a new thread that is running fn() and is identified (for debugging/performance-analysis) by "name".

View File

@ -1,31 +1,31 @@
# Class `tensorflow::RandomAccessFile` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--randomaccessfile-"></a>
# Class `tensorflow::RandomAccessFile`
A file abstraction for randomly reading the contents of a file.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::RandomAccessFile::RandomAccessFile()`](#tensorflow_RandomAccessFile_RandomAccessFile)
* [`virtual tensorflow::RandomAccessFile::~RandomAccessFile()`](#virtual_tensorflow_RandomAccessFile_RandomAccessFile)
* [`virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0`](#virtual_Status_tensorflow_RandomAccessFile_Read)
* Reads up to "n" bytes from the file starting at "offset".
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::RandomAccessFile::RandomAccessFile()` <a class="md-anchor" id="tensorflow_RandomAccessFile_RandomAccessFile"></a>
#### `tensorflow::RandomAccessFile::RandomAccessFile()` {#tensorflow_RandomAccessFile_RandomAccessFile}
#### `virtual tensorflow::RandomAccessFile::~RandomAccessFile()` <a class="md-anchor" id="virtual_tensorflow_RandomAccessFile_RandomAccessFile"></a>
#### `virtual tensorflow::RandomAccessFile::~RandomAccessFile()` {#virtual_tensorflow_RandomAccessFile_RandomAccessFile}
#### `virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0` <a class="md-anchor" id="virtual_Status_tensorflow_RandomAccessFile_Read"></a>
#### `virtual Status tensorflow::RandomAccessFile::Read(uint64 offset, size_t n, StringPiece *result, char *scratch) const =0` {#virtual_Status_tensorflow_RandomAccessFile_Read}
Reads up to "n" bytes from the file starting at "offset".

View File

@ -1,4 +1,4 @@
# Class `tensorflow::Session` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--session-"></a>
# Class `tensorflow::Session`
A Session instance lets a caller drive a TensorFlow graph computation.
@ -41,7 +41,7 @@ A Session allows concurrent calls to Run() , though a Session must be created /
Only one thread must call Close() , and Close() must only be called after all other calls to Run() have returned.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`virtual Status tensorflow::Session::Create(const GraphDef &graph)=0`](#virtual_Status_tensorflow_Session_Create)
* Create the graph to be used for the session.
@ -53,21 +53,21 @@ Only one thread must call Close() , and Close() must only be called after all ot
* Closes this session.
* [`virtual tensorflow::Session::~Session()`](#virtual_tensorflow_Session_Session)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `virtual Status tensorflow::Session::Create(const GraphDef &graph)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Session_Create"></a>
#### `virtual Status tensorflow::Session::Create(const GraphDef &graph)=0` {#virtual_Status_tensorflow_Session_Create}
Create the graph to be used for the session.
Returns an error if this session has already been created with a graph. To re-use the session with a different graph, the caller must Close() the session first.
#### `virtual Status tensorflow::Session::Extend(const GraphDef &graph)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Session_Extend"></a>
#### `virtual Status tensorflow::Session::Extend(const GraphDef &graph)=0` {#virtual_Status_tensorflow_Session_Extend}
Adds operations to the graph that is already registered with the Session .
The names of new operations in "graph" must not exist in the graph that is already registered.
#### `virtual Status tensorflow::Session::Run(const std::vector< std::pair< string, Tensor > > &inputs, const std::vector< string > &output_tensor_names, const std::vector< string > &target_node_names, std::vector< Tensor > *outputs)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Session_Run"></a>
#### `virtual Status tensorflow::Session::Run(const std::vector< std::pair< string, Tensor > > &inputs, const std::vector< string > &output_tensor_names, const std::vector< string > &target_node_names, std::vector< Tensor > *outputs)=0` {#virtual_Status_tensorflow_Session_Run}
Runs the graph with the provided input tensors and fills `outputs` for the endpoints specified in `output_tensor_names`. Runs to but does not return Tensors for the nodes in `target_node_names`.
@ -79,13 +79,13 @@ REQUIRES: The name of each Tensor of the input or output must match a "Tensor en
REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty.
#### `virtual Status tensorflow::Session::Close()=0` <a class="md-anchor" id="virtual_Status_tensorflow_Session_Close"></a>
#### `virtual Status tensorflow::Session::Close()=0` {#virtual_Status_tensorflow_Session_Close}
Closes this session.
Closing a session releases the resources used by this session on the TensorFlow runtime (specified during session creation by the ` SessionOptions::target ` field).
#### `virtual tensorflow::Session::~Session()` <a class="md-anchor" id="virtual_tensorflow_Session_Session"></a>
#### `virtual tensorflow::Session::~Session()` {#virtual_tensorflow_Session_Session}

View File

@ -1,10 +1,10 @@
# Class `tensorflow::Status` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--status-"></a>
# Class `tensorflow::Status`
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::Status::Status()`](#tensorflow_Status_Status)
* Create a success status.
@ -26,81 +26,81 @@
* Return a string representation of this status suitable for printing. Returns the string `"OK"` for success.
* [`static Status tensorflow::Status::OK()`](#static_Status_tensorflow_Status_OK)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::Status::Status()` <a class="md-anchor" id="tensorflow_Status_Status"></a>
#### `tensorflow::Status::Status()` {#tensorflow_Status_Status}
Create a success status.
#### `tensorflow::Status::~Status()` <a class="md-anchor" id="tensorflow_Status_Status"></a>
#### `tensorflow::Status::~Status()` {#tensorflow_Status_Status}
#### `tensorflow::Status::Status(tensorflow::error::Code code, tensorflow::StringPiece msg)` <a class="md-anchor" id="tensorflow_Status_Status"></a>
#### `tensorflow::Status::Status(tensorflow::error::Code code, tensorflow::StringPiece msg)` {#tensorflow_Status_Status}
Create a status with the specified error code and msg as a human-readable string containing more detailed information.
#### `tensorflow::Status::Status(const Status &s)` <a class="md-anchor" id="tensorflow_Status_Status"></a>
#### `tensorflow::Status::Status(const Status &s)` {#tensorflow_Status_Status}
Copy the specified status.
#### `void tensorflow::Status::operator=(const Status &s)` <a class="md-anchor" id="void_tensorflow_Status_operator_"></a>
#### `void tensorflow::Status::operator=(const Status &s)` {#void_tensorflow_Status_operator_}
#### `bool tensorflow::Status::ok() const` <a class="md-anchor" id="bool_tensorflow_Status_ok"></a>
#### `bool tensorflow::Status::ok() const` {#bool_tensorflow_Status_ok}
Returns true iff the status indicates success.
#### `tensorflow::error::Code tensorflow::Status::code() const` <a class="md-anchor" id="tensorflow_error_Code_tensorflow_Status_code"></a>
#### `tensorflow::error::Code tensorflow::Status::code() const` {#tensorflow_error_Code_tensorflow_Status_code}
#### `const string& tensorflow::Status::error_message() const` <a class="md-anchor" id="const_string_tensorflow_Status_error_message"></a>
#### `const string& tensorflow::Status::error_message() const` {#const_string_tensorflow_Status_error_message}
#### `bool tensorflow::Status::operator==(const Status &x) const` <a class="md-anchor" id="bool_tensorflow_Status_operator_"></a>
#### `bool tensorflow::Status::operator==(const Status &x) const` {#bool_tensorflow_Status_operator_}
#### `bool tensorflow::Status::operator!=(const Status &x) const` <a class="md-anchor" id="bool_tensorflow_Status_operator_"></a>
#### `bool tensorflow::Status::operator!=(const Status &x) const` {#bool_tensorflow_Status_operator_}
#### `void tensorflow::Status::Update(const Status &new_status)` <a class="md-anchor" id="void_tensorflow_Status_Update"></a>
#### `void tensorflow::Status::Update(const Status &new_status)` {#void_tensorflow_Status_Update}
If ` ok() `, stores `new_status` into `*this`. If `!ok()`, preserves the current status, but may augment with additional information about `new_status`.
Convenient way of keeping track of the first error encountered. Instead of: `if (overall_status.ok()) overall_status = new_status` Use: `overall_status.Update(new_status);`
#### `string tensorflow::Status::ToString() const` <a class="md-anchor" id="string_tensorflow_Status_ToString"></a>
#### `string tensorflow::Status::ToString() const` {#string_tensorflow_Status_ToString}
Return a string representation of this status suitable for printing. Returns the string `"OK"` for success.
#### `static Status tensorflow::Status::OK()` <a class="md-anchor" id="static_Status_tensorflow_Status_OK"></a>
#### `static Status tensorflow::Status::OK()` {#static_Status_tensorflow_Status_OK}

View File

@ -1,10 +1,10 @@
# Class `tensorflow::Tensor` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--tensor-"></a>
# Class `tensorflow::Tensor`
Represents an n-dimensional array of values.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::Tensor::Tensor()`](#tensorflow_Tensor_Tensor)
* Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
@ -76,105 +76,105 @@ Represents an n-dimensional array of values.
* [`StringPiece tensorflow::Tensor::tensor_data() const`](#StringPiece_tensorflow_Tensor_tensor_data)
* Returns a `StringPiece` mapping the current tensor&apos;s buffer.
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::Tensor::Tensor()` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
#### `tensorflow::Tensor::Tensor()` {#tensorflow_Tensor_Tensor}
Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
#### `tensorflow::Tensor::Tensor(DataType type, const TensorShape &shape)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
#### `tensorflow::Tensor::Tensor(DataType type, const TensorShape &shape)` {#tensorflow_Tensor_Tensor}
Creates a Tensor of the given `type` and `shape`.
The underlying buffer is allocated using a `CPUAllocator`.
#### `tensorflow::Tensor::Tensor(Allocator *a, DataType type, const TensorShape &shape)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
#### `tensorflow::Tensor::Tensor(Allocator *a, DataType type, const TensorShape &shape)` {#tensorflow_Tensor_Tensor}
Creates a tensor with the input `type` and `shape`, using the allocator `a` to allocate the underlying buffer.
`a` must outlive the lifetime of this Tensor .
#### `tensorflow::Tensor::Tensor(DataType type)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
#### `tensorflow::Tensor::Tensor(DataType type)` {#tensorflow_Tensor_Tensor}
Creates an uninitialized Tensor of the given data type.
#### `tensorflow::Tensor::Tensor(const Tensor &other)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
#### `tensorflow::Tensor::Tensor(const Tensor &other)` {#tensorflow_Tensor_Tensor}
#### `tensorflow::Tensor::~Tensor()` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
#### `tensorflow::Tensor::~Tensor()` {#tensorflow_Tensor_Tensor}
Copy constructor.
#### `DataType tensorflow::Tensor::dtype() const` <a class="md-anchor" id="DataType_tensorflow_Tensor_dtype"></a>
#### `DataType tensorflow::Tensor::dtype() const` {#DataType_tensorflow_Tensor_dtype}
Returns the data type.
#### `const TensorShape& tensorflow::Tensor::shape() const` <a class="md-anchor" id="const_TensorShape_tensorflow_Tensor_shape"></a>
#### `const TensorShape& tensorflow::Tensor::shape() const` {#const_TensorShape_tensorflow_Tensor_shape}
Returns the shape of the tensor.
#### `int tensorflow::Tensor::dims() const` <a class="md-anchor" id="int_tensorflow_Tensor_dims"></a>
#### `int tensorflow::Tensor::dims() const` {#int_tensorflow_Tensor_dims}
Convenience accessor for the tensor shape.
For all shape accessors, see comments for relevant methods of ` TensorShape ` in ` tensor_shape.h `.
#### `int64 tensorflow::Tensor::dim_size(int d) const` <a class="md-anchor" id="int64_tensorflow_Tensor_dim_size"></a>
#### `int64 tensorflow::Tensor::dim_size(int d) const` {#int64_tensorflow_Tensor_dim_size}
Convenience accessor for the tensor shape.
#### `int64 tensorflow::Tensor::NumElements() const` <a class="md-anchor" id="int64_tensorflow_Tensor_NumElements"></a>
#### `int64 tensorflow::Tensor::NumElements() const` {#int64_tensorflow_Tensor_NumElements}
Convenience accessor for the tensor shape.
#### `bool tensorflow::Tensor::IsSameSize(const Tensor &b) const` <a class="md-anchor" id="bool_tensorflow_Tensor_IsSameSize"></a>
#### `bool tensorflow::Tensor::IsSameSize(const Tensor &b) const` {#bool_tensorflow_Tensor_IsSameSize}
#### `bool tensorflow::Tensor::IsInitialized() const` <a class="md-anchor" id="bool_tensorflow_Tensor_IsInitialized"></a>
#### `bool tensorflow::Tensor::IsInitialized() const` {#bool_tensorflow_Tensor_IsInitialized}
Has this Tensor been initialized?
#### `size_t tensorflow::Tensor::TotalBytes() const` <a class="md-anchor" id="size_t_tensorflow_Tensor_TotalBytes"></a>
#### `size_t tensorflow::Tensor::TotalBytes() const` {#size_t_tensorflow_Tensor_TotalBytes}
Returns the estimated memory usage of this tensor.
#### `Tensor& tensorflow::Tensor::operator=(const Tensor &other)` <a class="md-anchor" id="Tensor_tensorflow_Tensor_operator_"></a>
#### `Tensor& tensorflow::Tensor::operator=(const Tensor &other)` {#Tensor_tensorflow_Tensor_operator_}
Assign operator. This tensor shares other&apos;s underlying storage.
#### `bool tensorflow::Tensor::CopyFrom(const Tensor &other, const TensorShape &shape) TF_MUST_USE_RESULT` <a class="md-anchor" id="bool_tensorflow_Tensor_CopyFrom"></a>
#### `bool tensorflow::Tensor::CopyFrom(const Tensor &other, const TensorShape &shape) TF_MUST_USE_RESULT` {#bool_tensorflow_Tensor_CopyFrom}
Copy the other tensor into this tensor and reshape it.
This tensor shares other&apos;s underlying storage. Returns `true` iff `other.shape()` has the same number of elements of the given `shape`.
#### `Tensor tensorflow::Tensor::Slice(int64 dim0_start, int64 dim0_limit) const` <a class="md-anchor" id="Tensor_tensorflow_Tensor_Slice"></a>
#### `Tensor tensorflow::Tensor::Slice(int64 dim0_start, int64 dim0_limit) const` {#Tensor_tensorflow_Tensor_Slice}
Slice this tensor along the 1st dimension.
@ -184,31 +184,31 @@ NOTE: The returned tensor may not satisfies the same alignment requirement as th
REQUIRES: ` dims() ` >= 1 REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
#### `bool tensorflow::Tensor::FromProto(const TensorProto &other) TF_MUST_USE_RESULT` <a class="md-anchor" id="bool_tensorflow_Tensor_FromProto"></a>
#### `bool tensorflow::Tensor::FromProto(const TensorProto &other) TF_MUST_USE_RESULT` {#bool_tensorflow_Tensor_FromProto}
Parse `other` and construct the tensor.
Returns `true` iff the parsing succeeds. If the parsing fails, the state of `*this` is unchanged.
#### `bool tensorflow::Tensor::FromProto(Allocator *a, const TensorProto &other) TF_MUST_USE_RESULT` <a class="md-anchor" id="bool_tensorflow_Tensor_FromProto"></a>
#### `bool tensorflow::Tensor::FromProto(Allocator *a, const TensorProto &other) TF_MUST_USE_RESULT` {#bool_tensorflow_Tensor_FromProto}
#### `void tensorflow::Tensor::AsProtoField(TensorProto *proto) const` <a class="md-anchor" id="void_tensorflow_Tensor_AsProtoField"></a>
#### `void tensorflow::Tensor::AsProtoField(TensorProto *proto) const` {#void_tensorflow_Tensor_AsProtoField}
Fills in `proto` with `*this` tensor&apos;s content.
` AsProtoField() ` fills in the repeated field for `proto.dtype()`, while `AsProtoTensorContent()` encodes the content in `proto.tensor_content()` in a compact form.
#### `void tensorflow::Tensor::AsProtoTensorContent(TensorProto *proto) const` <a class="md-anchor" id="void_tensorflow_Tensor_AsProtoTensorContent"></a>
#### `void tensorflow::Tensor::AsProtoTensorContent(TensorProto *proto) const` {#void_tensorflow_Tensor_AsProtoTensorContent}
#### `TTypes<T>::Vec tensorflow::Tensor::vec()` <a class="md-anchor" id="TTypes_T_Vec_tensorflow_Tensor_vec"></a>
#### `TTypes<T>::Vec tensorflow::Tensor::vec()` {#TTypes_T_Vec_tensorflow_Tensor_vec}
Return the tensor data as an `Eigen::Tensor` with the type and sizes of this ` Tensor `.
@ -226,19 +226,19 @@ auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
```
#### `TTypes<T>::Matrix tensorflow::Tensor::matrix()` <a class="md-anchor" id="TTypes_T_Matrix_tensorflow_Tensor_matrix"></a>
#### `TTypes<T>::Matrix tensorflow::Tensor::matrix()` {#TTypes_T_Matrix_tensorflow_Tensor_matrix}
#### `TTypes< T, NDIMS >::Tensor tensorflow::Tensor::tensor()` <a class="md-anchor" id="TTypes_T_NDIMS_Tensor_tensorflow_Tensor_tensor"></a>
#### `TTypes< T, NDIMS >::Tensor tensorflow::Tensor::tensor()` {#TTypes_T_NDIMS_Tensor_tensorflow_Tensor_tensor}
#### `TTypes<T>::Flat tensorflow::Tensor::flat()` <a class="md-anchor" id="TTypes_T_Flat_tensorflow_Tensor_flat"></a>
#### `TTypes<T>::Flat tensorflow::Tensor::flat()` {#TTypes_T_Flat_tensorflow_Tensor_flat}
Return the tensor data as an `Eigen::Tensor` of the data type and a specified shape.
@ -263,121 +263,121 @@ auto bad = my_ten.flat<int32>();
```
#### `TTypes<T>::UnalignedFlat tensorflow::Tensor::unaligned_flat()` <a class="md-anchor" id="TTypes_T_UnalignedFlat_tensorflow_Tensor_unaligned_flat"></a>
#### `TTypes<T>::UnalignedFlat tensorflow::Tensor::unaligned_flat()` {#TTypes_T_UnalignedFlat_tensorflow_Tensor_unaligned_flat}
#### `TTypes<T>::Matrix tensorflow::Tensor::flat_inner_dims()` <a class="md-anchor" id="TTypes_T_Matrix_tensorflow_Tensor_flat_inner_dims"></a>
#### `TTypes<T>::Matrix tensorflow::Tensor::flat_inner_dims()` {#TTypes_T_Matrix_tensorflow_Tensor_flat_inner_dims}
Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all Tensor dimensions but the last one into the first dimension of the result.
#### `TTypes<T>::Matrix tensorflow::Tensor::flat_outer_dims()` <a class="md-anchor" id="TTypes_T_Matrix_tensorflow_Tensor_flat_outer_dims"></a>
#### `TTypes<T>::Matrix tensorflow::Tensor::flat_outer_dims()` {#TTypes_T_Matrix_tensorflow_Tensor_flat_outer_dims}
Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all Tensor dimensions but the first one into the last dimension of the result.
#### `TTypes< T, NDIMS >::Tensor tensorflow::Tensor::shaped(gtl::ArraySlice< int64 > new_sizes)` <a class="md-anchor" id="TTypes_T_NDIMS_Tensor_tensorflow_Tensor_shaped"></a>
#### `TTypes< T, NDIMS >::Tensor tensorflow::Tensor::shaped(gtl::ArraySlice< int64 > new_sizes)` {#TTypes_T_NDIMS_Tensor_tensorflow_Tensor_shaped}
#### `TTypes< T, NDIMS >::UnalignedTensor tensorflow::Tensor::unaligned_shaped(gtl::ArraySlice< int64 > new_sizes)` <a class="md-anchor" id="TTypes_T_NDIMS_UnalignedTensor_tensorflow_Tensor_unaligned_shaped"></a>
#### `TTypes< T, NDIMS >::UnalignedTensor tensorflow::Tensor::unaligned_shaped(gtl::ArraySlice< int64 > new_sizes)` {#TTypes_T_NDIMS_UnalignedTensor_tensorflow_Tensor_unaligned_shaped}
#### `TTypes< T >::Scalar tensorflow::Tensor::scalar()` <a class="md-anchor" id="TTypes_T_Scalar_tensorflow_Tensor_scalar"></a>
#### `TTypes< T >::Scalar tensorflow::Tensor::scalar()` {#TTypes_T_Scalar_tensorflow_Tensor_scalar}
Return the Tensor data as a `TensorMap` of fixed size 1: `TensorMap<TensorFixedSize<T, 1>>`.
Using ` scalar() ` allows the compiler to perform optimizations as the size of the tensor is known at compile time.
#### `TTypes<T>::ConstVec tensorflow::Tensor::vec() const` <a class="md-anchor" id="TTypes_T_ConstVec_tensorflow_Tensor_vec"></a>
#### `TTypes<T>::ConstVec tensorflow::Tensor::vec() const` {#TTypes_T_ConstVec_tensorflow_Tensor_vec}
Const versions of all the methods above.
#### `TTypes<T>::ConstMatrix tensorflow::Tensor::matrix() const` <a class="md-anchor" id="TTypes_T_ConstMatrix_tensorflow_Tensor_matrix"></a>
#### `TTypes<T>::ConstMatrix tensorflow::Tensor::matrix() const` {#TTypes_T_ConstMatrix_tensorflow_Tensor_matrix}
#### `TTypes< T, NDIMS >::ConstTensor tensorflow::Tensor::tensor() const` <a class="md-anchor" id="TTypes_T_NDIMS_ConstTensor_tensorflow_Tensor_tensor"></a>
#### `TTypes< T, NDIMS >::ConstTensor tensorflow::Tensor::tensor() const` {#TTypes_T_NDIMS_ConstTensor_tensorflow_Tensor_tensor}
#### `TTypes<T>::ConstFlat tensorflow::Tensor::flat() const` <a class="md-anchor" id="TTypes_T_ConstFlat_tensorflow_Tensor_flat"></a>
#### `TTypes<T>::ConstFlat tensorflow::Tensor::flat() const` {#TTypes_T_ConstFlat_tensorflow_Tensor_flat}
#### `TTypes<T>::UnalignedConstFlat tensorflow::Tensor::unaligned_flat() const` <a class="md-anchor" id="TTypes_T_UnalignedConstFlat_tensorflow_Tensor_unaligned_flat"></a>
#### `TTypes<T>::UnalignedConstFlat tensorflow::Tensor::unaligned_flat() const` {#TTypes_T_UnalignedConstFlat_tensorflow_Tensor_unaligned_flat}
#### `TTypes<T>::ConstMatrix tensorflow::Tensor::flat_inner_dims() const` <a class="md-anchor" id="TTypes_T_ConstMatrix_tensorflow_Tensor_flat_inner_dims"></a>
#### `TTypes<T>::ConstMatrix tensorflow::Tensor::flat_inner_dims() const` {#TTypes_T_ConstMatrix_tensorflow_Tensor_flat_inner_dims}
#### `TTypes<T>::ConstMatrix tensorflow::Tensor::flat_outer_dims() const` <a class="md-anchor" id="TTypes_T_ConstMatrix_tensorflow_Tensor_flat_outer_dims"></a>
#### `TTypes<T>::ConstMatrix tensorflow::Tensor::flat_outer_dims() const` {#TTypes_T_ConstMatrix_tensorflow_Tensor_flat_outer_dims}
#### `TTypes< T, NDIMS >::ConstTensor tensorflow::Tensor::shaped(gtl::ArraySlice< int64 > new_sizes) const` <a class="md-anchor" id="TTypes_T_NDIMS_ConstTensor_tensorflow_Tensor_shaped"></a>
#### `TTypes< T, NDIMS >::ConstTensor tensorflow::Tensor::shaped(gtl::ArraySlice< int64 > new_sizes) const` {#TTypes_T_NDIMS_ConstTensor_tensorflow_Tensor_shaped}
#### `TTypes< T, NDIMS >::UnalignedConstTensor tensorflow::Tensor::unaligned_shaped(gtl::ArraySlice< int64 > new_sizes) const` <a class="md-anchor" id="TTypes_T_NDIMS_UnalignedConstTensor_tensorflow_Tensor_unaligned_shaped"></a>
#### `TTypes< T, NDIMS >::UnalignedConstTensor tensorflow::Tensor::unaligned_shaped(gtl::ArraySlice< int64 > new_sizes) const` {#TTypes_T_NDIMS_UnalignedConstTensor_tensorflow_Tensor_unaligned_shaped}
#### `TTypes< T >::ConstScalar tensorflow::Tensor::scalar() const` <a class="md-anchor" id="TTypes_T_ConstScalar_tensorflow_Tensor_scalar"></a>
#### `TTypes< T >::ConstScalar tensorflow::Tensor::scalar() const` {#TTypes_T_ConstScalar_tensorflow_Tensor_scalar}
#### `string tensorflow::Tensor::SummarizeValue(int64 max_entries) const` <a class="md-anchor" id="string_tensorflow_Tensor_SummarizeValue"></a>
#### `string tensorflow::Tensor::SummarizeValue(int64 max_entries) const` {#string_tensorflow_Tensor_SummarizeValue}
Render the first `max_entries` values in `*this` into a string.
#### `string tensorflow::Tensor::DebugString() const` <a class="md-anchor" id="string_tensorflow_Tensor_DebugString"></a>
#### `string tensorflow::Tensor::DebugString() const` {#string_tensorflow_Tensor_DebugString}
A human-readable summary of the tensor suitable for debugging.
#### `void tensorflow::Tensor::FillDescription(TensorDescription *description) const` <a class="md-anchor" id="void_tensorflow_Tensor_FillDescription"></a>
#### `void tensorflow::Tensor::FillDescription(TensorDescription *description) const` {#void_tensorflow_Tensor_FillDescription}
Fill in the `TensorDescription` proto with metadata about the tensor that is useful for monitoring and debugging.
#### `StringPiece tensorflow::Tensor::tensor_data() const` <a class="md-anchor" id="StringPiece_tensorflow_Tensor_tensor_data"></a>
#### `StringPiece tensorflow::Tensor::tensor_data() const` {#StringPiece_tensorflow_Tensor_tensor_data}
Returns a `StringPiece` mapping the current tensor&apos;s buffer.

View File

@ -1,10 +1,10 @@
# Class `tensorflow::TensorShape` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--tensorshape-"></a>
# Class `tensorflow::TensorShape`
Manages the dimensions of a Tensor and their sizes.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::TensorShape::TensorShape(gtl::ArraySlice< int64 > dim_sizes)`](#tensorflow_TensorShape_TensorShape)
* Construct a ` TensorShape ` from the provided sizes. REQUIRES: `dim_sizes[i] >= 0`
@ -48,147 +48,147 @@ Manages the dimensions of a Tensor and their sizes.
* [`static bool tensorflow::TensorShape::IsValid(const TensorShapeProto &proto)`](#static_bool_tensorflow_TensorShape_IsValid)
* Returns `true` iff `proto` is a valid tensor shape.
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::TensorShape::TensorShape(gtl::ArraySlice< int64 > dim_sizes)` <a class="md-anchor" id="tensorflow_TensorShape_TensorShape"></a>
#### `tensorflow::TensorShape::TensorShape(gtl::ArraySlice< int64 > dim_sizes)` {#tensorflow_TensorShape_TensorShape}
Construct a ` TensorShape ` from the provided sizes. REQUIRES: `dim_sizes[i] >= 0`
#### `tensorflow::TensorShape::TensorShape(std::initializer_list< int64 > dim_sizes)` <a class="md-anchor" id="tensorflow_TensorShape_TensorShape"></a>
#### `tensorflow::TensorShape::TensorShape(std::initializer_list< int64 > dim_sizes)` {#tensorflow_TensorShape_TensorShape}
#### `tensorflow::TensorShape::TensorShape(const TensorShapeProto &proto)` <a class="md-anchor" id="tensorflow_TensorShape_TensorShape"></a>
#### `tensorflow::TensorShape::TensorShape(const TensorShapeProto &proto)` {#tensorflow_TensorShape_TensorShape}
REQUIRES: `IsValid(proto)`
#### `tensorflow::TensorShape::TensorShape()` <a class="md-anchor" id="tensorflow_TensorShape_TensorShape"></a>
#### `tensorflow::TensorShape::TensorShape()` {#tensorflow_TensorShape_TensorShape}
Create a tensor shape with no dimensions and one element, which you can then call ` AddDim() ` on.
#### `void tensorflow::TensorShape::Clear()` <a class="md-anchor" id="void_tensorflow_TensorShape_Clear"></a>
#### `void tensorflow::TensorShape::Clear()` {#void_tensorflow_TensorShape_Clear}
Clear a tensor shape.
#### `void tensorflow::TensorShape::AddDim(int64 size)` <a class="md-anchor" id="void_tensorflow_TensorShape_AddDim"></a>
#### `void tensorflow::TensorShape::AddDim(int64 size)` {#void_tensorflow_TensorShape_AddDim}
Add a dimension to the end ("inner-most"). REQUIRES: `size >= 0`
#### `void tensorflow::TensorShape::AppendShape(const TensorShape &shape)` <a class="md-anchor" id="void_tensorflow_TensorShape_AppendShape"></a>
#### `void tensorflow::TensorShape::AppendShape(const TensorShape &shape)` {#void_tensorflow_TensorShape_AppendShape}
Appends all the dimensions from `shape`.
#### `void tensorflow::TensorShape::InsertDim(int d, int64 size)` <a class="md-anchor" id="void_tensorflow_TensorShape_InsertDim"></a>
#### `void tensorflow::TensorShape::InsertDim(int d, int64 size)` {#void_tensorflow_TensorShape_InsertDim}
Insert a dimension somewhere in the ` TensorShape `. REQUIRES: `0 <= d <= dims() ` REQUIRES: `size >= 0`
#### `void tensorflow::TensorShape::set_dim(int d, int64 size)` <a class="md-anchor" id="void_tensorflow_TensorShape_set_dim"></a>
#### `void tensorflow::TensorShape::set_dim(int d, int64 size)` {#void_tensorflow_TensorShape_set_dim}
Modifies the size of the dimension `d` to be `size` REQUIRES: `0 <= d < dims() ` REQUIRES: `size >= 0`
#### `void tensorflow::TensorShape::RemoveDim(int d)` <a class="md-anchor" id="void_tensorflow_TensorShape_RemoveDim"></a>
#### `void tensorflow::TensorShape::RemoveDim(int d)` {#void_tensorflow_TensorShape_RemoveDim}
Removes dimension `d` from the ` TensorShape `. REQUIRES: `0 <= d < dims() `
#### `int tensorflow::TensorShape::dims() const` <a class="md-anchor" id="int_tensorflow_TensorShape_dims"></a>
#### `int tensorflow::TensorShape::dims() const` {#int_tensorflow_TensorShape_dims}
Return the number of dimensions in the tensor.
#### `int64 tensorflow::TensorShape::dim_size(int d) const` <a class="md-anchor" id="int64_tensorflow_TensorShape_dim_size"></a>
#### `int64 tensorflow::TensorShape::dim_size(int d) const` {#int64_tensorflow_TensorShape_dim_size}
Returns the number of elements in dimension `d`. REQUIRES: `0 <= d < dims() `
#### `gtl::ArraySlice<int64> tensorflow::TensorShape::dim_sizes() const` <a class="md-anchor" id="gtl_ArraySlice_int64_tensorflow_TensorShape_dim_sizes"></a>
#### `gtl::ArraySlice<int64> tensorflow::TensorShape::dim_sizes() const` {#gtl_ArraySlice_int64_tensorflow_TensorShape_dim_sizes}
Returns sizes of all dimensions.
#### `int64 tensorflow::TensorShape::num_elements() const` <a class="md-anchor" id="int64_tensorflow_TensorShape_num_elements"></a>
#### `int64 tensorflow::TensorShape::num_elements() const` {#int64_tensorflow_TensorShape_num_elements}
Returns the number of elements in the tensor.
We use `int64` and not `size_t` to be compatible with `Eigen::Tensor` which uses `ptrdiff_t`.
#### `bool tensorflow::TensorShape::IsSameSize(const TensorShape &b) const` <a class="md-anchor" id="bool_tensorflow_TensorShape_IsSameSize"></a>
#### `bool tensorflow::TensorShape::IsSameSize(const TensorShape &b) const` {#bool_tensorflow_TensorShape_IsSameSize}
Returns true if `*this` and `b` have the same sizes. Ignores dimension names.
#### `bool tensorflow::TensorShape::operator==(const TensorShape &b) const` <a class="md-anchor" id="bool_tensorflow_TensorShape_operator_"></a>
#### `bool tensorflow::TensorShape::operator==(const TensorShape &b) const` {#bool_tensorflow_TensorShape_operator_}
#### `void tensorflow::TensorShape::AsProto(TensorShapeProto *proto) const` <a class="md-anchor" id="void_tensorflow_TensorShape_AsProto"></a>
#### `void tensorflow::TensorShape::AsProto(TensorShapeProto *proto) const` {#void_tensorflow_TensorShape_AsProto}
Fill `*proto` from `*this`.
#### `Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizes() const` <a class="md-anchor" id="Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizes"></a>
#### `Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizes() const` {#Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizes}
Fill `*dsizes` from `*this`.
#### `Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizesWithPadding() const` <a class="md-anchor" id="Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizesWithPadding"></a>
#### `Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizesWithPadding() const` {#Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizesWithPadding}
Same as ` AsEigenDSizes() ` but allows for `NDIMS > dims() ` in which case we pad the rest of the sizes with 1.
#### `TensorShapeIter tensorflow::TensorShape::begin() const` <a class="md-anchor" id="TensorShapeIter_tensorflow_TensorShape_begin"></a>
#### `TensorShapeIter tensorflow::TensorShape::begin() const` {#TensorShapeIter_tensorflow_TensorShape_begin}
For iterating through the dimensions.
#### `TensorShapeIter tensorflow::TensorShape::end() const` <a class="md-anchor" id="TensorShapeIter_tensorflow_TensorShape_end"></a>
#### `TensorShapeIter tensorflow::TensorShape::end() const` {#TensorShapeIter_tensorflow_TensorShape_end}
#### `string tensorflow::TensorShape::DebugString() const` <a class="md-anchor" id="string_tensorflow_TensorShape_DebugString"></a>
#### `string tensorflow::TensorShape::DebugString() const` {#string_tensorflow_TensorShape_DebugString}
For error messages.
#### `string tensorflow::TensorShape::ShortDebugString() const` <a class="md-anchor" id="string_tensorflow_TensorShape_ShortDebugString"></a>
#### `string tensorflow::TensorShape::ShortDebugString() const` {#string_tensorflow_TensorShape_ShortDebugString}
#### `static bool tensorflow::TensorShape::IsValid(const TensorShapeProto &proto)` <a class="md-anchor" id="static_bool_tensorflow_TensorShape_IsValid"></a>
#### `static bool tensorflow::TensorShape::IsValid(const TensorShapeProto &proto)` {#static_bool_tensorflow_TensorShape_IsValid}
Returns `true` iff `proto` is a valid tensor shape.

View File

@ -1,10 +1,10 @@
# Class `tensorflow::TensorShapeUtils` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--tensorshapeutils-"></a>
# Class `tensorflow::TensorShapeUtils`
Static helper routines for ` TensorShape `. Includes a few common predicates on a tensor shape.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`static bool tensorflow::TensorShapeUtils::IsScalar(const TensorShape &shape)`](#static_bool_tensorflow_TensorShapeUtils_IsScalar)
* [`static bool tensorflow::TensorShapeUtils::IsVector(const TensorShape &shape)`](#static_bool_tensorflow_TensorShapeUtils_IsVector)
@ -18,63 +18,63 @@ Static helper routines for ` TensorShape `. Includes a few common predicates on
* [`static string tensorflow::TensorShapeUtils::ShapeListString(const gtl::ArraySlice< TensorShape > &shapes)`](#static_string_tensorflow_TensorShapeUtils_ShapeListString)
* [`static bool tensorflow::TensorShapeUtils::StartsWith(const TensorShape &shape0, const TensorShape &shape1)`](#static_bool_tensorflow_TensorShapeUtils_StartsWith)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `static bool tensorflow::TensorShapeUtils::IsScalar(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsScalar"></a>
#### `static bool tensorflow::TensorShapeUtils::IsScalar(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsScalar}
#### `static bool tensorflow::TensorShapeUtils::IsVector(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsVector"></a>
#### `static bool tensorflow::TensorShapeUtils::IsVector(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsVector}
#### `static bool tensorflow::TensorShapeUtils::IsLegacyScalar(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsLegacyScalar"></a>
#### `static bool tensorflow::TensorShapeUtils::IsLegacyScalar(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsLegacyScalar}
#### `static bool tensorflow::TensorShapeUtils::IsLegacyVector(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsLegacyVector"></a>
#### `static bool tensorflow::TensorShapeUtils::IsLegacyVector(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsLegacyVector}
#### `static bool tensorflow::TensorShapeUtils::IsVectorOrHigher(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsVectorOrHigher"></a>
#### `static bool tensorflow::TensorShapeUtils::IsVectorOrHigher(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsVectorOrHigher}
#### `static bool tensorflow::TensorShapeUtils::IsMatrix(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsMatrix"></a>
#### `static bool tensorflow::TensorShapeUtils::IsMatrix(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsMatrix}
#### `static bool tensorflow::TensorShapeUtils::IsMatrixOrHigher(const TensorShape &shape)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_IsMatrixOrHigher"></a>
#### `static bool tensorflow::TensorShapeUtils::IsMatrixOrHigher(const TensorShape &shape)` {#static_bool_tensorflow_TensorShapeUtils_IsMatrixOrHigher}
#### `static TensorShape tensorflow::TensorShapeUtils::MakeShape(const T *dims, int n)` <a class="md-anchor" id="static_TensorShape_tensorflow_TensorShapeUtils_MakeShape"></a>
#### `static TensorShape tensorflow::TensorShapeUtils::MakeShape(const T *dims, int n)` {#static_TensorShape_tensorflow_TensorShapeUtils_MakeShape}
Returns a ` TensorShape ` whose dimensions are `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
#### `static string tensorflow::TensorShapeUtils::ShapeListString(const gtl::ArraySlice< TensorShape > &shapes)` <a class="md-anchor" id="static_string_tensorflow_TensorShapeUtils_ShapeListString"></a>
#### `static string tensorflow::TensorShapeUtils::ShapeListString(const gtl::ArraySlice< TensorShape > &shapes)` {#static_string_tensorflow_TensorShapeUtils_ShapeListString}
#### `static bool tensorflow::TensorShapeUtils::StartsWith(const TensorShape &shape0, const TensorShape &shape1)` <a class="md-anchor" id="static_bool_tensorflow_TensorShapeUtils_StartsWith"></a>
#### `static bool tensorflow::TensorShapeUtils::StartsWith(const TensorShape &shape0, const TensorShape &shape1)` {#static_bool_tensorflow_TensorShapeUtils_StartsWith}

View File

@ -1,24 +1,24 @@
# Class `tensorflow::Thread` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--thread-"></a>
# Class `tensorflow::Thread`
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::Thread::Thread()`](#tensorflow_Thread_Thread)
* [`virtual tensorflow::Thread::~Thread()`](#virtual_tensorflow_Thread_Thread)
* Blocks until the thread of control stops running.
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::Thread::Thread()` <a class="md-anchor" id="tensorflow_Thread_Thread"></a>
#### `tensorflow::Thread::Thread()` {#tensorflow_Thread_Thread}
#### `virtual tensorflow::Thread::~Thread()` <a class="md-anchor" id="virtual_tensorflow_Thread_Thread"></a>
#### `virtual tensorflow::Thread::~Thread()` {#virtual_tensorflow_Thread_Thread}
Blocks until the thread of control stops running.

View File

@ -1,10 +1,10 @@
# Class `tensorflow::WritableFile` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--writablefile-"></a>
# Class `tensorflow::WritableFile`
A file abstraction for sequential writing.
The implementation must provide buffering since callers may append small fragments at a time to the file.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::WritableFile::WritableFile()`](#tensorflow_WritableFile_WritableFile)
* [`virtual tensorflow::WritableFile::~WritableFile()`](#virtual_tensorflow_WritableFile_WritableFile)
@ -13,39 +13,39 @@ The implementation must provide buffering since callers may append small fragmen
* [`virtual Status tensorflow::WritableFile::Flush()=0`](#virtual_Status_tensorflow_WritableFile_Flush)
* [`virtual Status tensorflow::WritableFile::Sync()=0`](#virtual_Status_tensorflow_WritableFile_Sync)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::WritableFile::WritableFile()` <a class="md-anchor" id="tensorflow_WritableFile_WritableFile"></a>
#### `tensorflow::WritableFile::WritableFile()` {#tensorflow_WritableFile_WritableFile}
#### `virtual tensorflow::WritableFile::~WritableFile()` <a class="md-anchor" id="virtual_tensorflow_WritableFile_WritableFile"></a>
#### `virtual tensorflow::WritableFile::~WritableFile()` {#virtual_tensorflow_WritableFile_WritableFile}
#### `virtual Status tensorflow::WritableFile::Append(const StringPiece &data)=0` <a class="md-anchor" id="virtual_Status_tensorflow_WritableFile_Append"></a>
#### `virtual Status tensorflow::WritableFile::Append(const StringPiece &data)=0` {#virtual_Status_tensorflow_WritableFile_Append}
#### `virtual Status tensorflow::WritableFile::Close()=0` <a class="md-anchor" id="virtual_Status_tensorflow_WritableFile_Close"></a>
#### `virtual Status tensorflow::WritableFile::Close()=0` {#virtual_Status_tensorflow_WritableFile_Close}
#### `virtual Status tensorflow::WritableFile::Flush()=0` <a class="md-anchor" id="virtual_Status_tensorflow_WritableFile_Flush"></a>
#### `virtual Status tensorflow::WritableFile::Flush()=0` {#virtual_Status_tensorflow_WritableFile_Flush}
#### `virtual Status tensorflow::WritableFile::Sync()=0` <a class="md-anchor" id="virtual_Status_tensorflow_WritableFile_Sync"></a>
#### `virtual Status tensorflow::WritableFile::Sync()=0` {#virtual_Status_tensorflow_WritableFile_Sync}

View File

@ -1,10 +1,10 @@
# Struct `tensorflow::SessionOptions` <a class="md-anchor" id="AUTOGENERATED-struct--tensorflow--sessionoptions-"></a>
# Struct `tensorflow::SessionOptions`
Configuration information for a Session .
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`Env* tensorflow::SessionOptions::env`](#Env_tensorflow_SessionOptions_env)
* The environment to use.
@ -14,15 +14,15 @@ Configuration information for a Session .
* Configuration options.
* [`tensorflow::SessionOptions::SessionOptions()`](#tensorflow_SessionOptions_SessionOptions)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `Env* tensorflow::SessionOptions::env` <a class="md-anchor" id="Env_tensorflow_SessionOptions_env"></a>
#### `Env* tensorflow::SessionOptions::env` {#Env_tensorflow_SessionOptions_env}
The environment to use.
#### `string tensorflow::SessionOptions::target` <a class="md-anchor" id="string_tensorflow_SessionOptions_target"></a>
#### `string tensorflow::SessionOptions::target` {#string_tensorflow_SessionOptions_target}
The TensorFlow runtime to connect to.
@ -36,13 +36,13 @@ Upon creation, a single session affines itself to one of the remote processes, w
If the session disconnects from the remote process during its lifetime, session calls may fail immediately.
#### `ConfigProto tensorflow::SessionOptions::config` <a class="md-anchor" id="ConfigProto_tensorflow_SessionOptions_config"></a>
#### `ConfigProto tensorflow::SessionOptions::config` {#ConfigProto_tensorflow_SessionOptions_config}
Configuration options.
#### `tensorflow::SessionOptions::SessionOptions()` <a class="md-anchor" id="tensorflow_SessionOptions_SessionOptions"></a>
#### `tensorflow::SessionOptions::SessionOptions()` {#tensorflow_SessionOptions_SessionOptions}

View File

@ -1,23 +1,23 @@
# Struct `tensorflow::Status::State` <a class="md-anchor" id="AUTOGENERATED-struct--tensorflow--status--state-"></a>
# Struct `tensorflow::Status::State`
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`tensorflow::error::Code tensorflow::Status::State::code`](#tensorflow_error_Code_tensorflow_Status_State_code)
* [`string tensorflow::Status::State::msg`](#string_tensorflow_Status_State_msg)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `tensorflow::error::Code tensorflow::Status::State::code` <a class="md-anchor" id="tensorflow_error_Code_tensorflow_Status_State_code"></a>
#### `tensorflow::error::Code tensorflow::Status::State::code` {#tensorflow_error_Code_tensorflow_Status_State_code}
#### `string tensorflow::Status::State::msg` <a class="md-anchor" id="string_tensorflow_Status_State_msg"></a>
#### `string tensorflow::Status::State::msg` {#string_tensorflow_Status_State_msg}

View File

@ -1,23 +1,23 @@
# Struct `tensorflow::TensorShapeDim` <a class="md-anchor" id="AUTOGENERATED-struct--tensorflow--tensorshapedim-"></a>
# Struct `tensorflow::TensorShapeDim`
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`int tensorflow::TensorShapeDim::size`](#int_tensorflow_TensorShapeDim_size)
* [`tensorflow::TensorShapeDim::TensorShapeDim(int64 s)`](#tensorflow_TensorShapeDim_TensorShapeDim)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `int tensorflow::TensorShapeDim::size` <a class="md-anchor" id="int_tensorflow_TensorShapeDim_size"></a>
#### `int tensorflow::TensorShapeDim::size` {#int_tensorflow_TensorShapeDim_size}
#### `tensorflow::TensorShapeDim::TensorShapeDim(int64 s)` <a class="md-anchor" id="tensorflow_TensorShapeDim_TensorShapeDim"></a>
#### `tensorflow::TensorShapeDim::TensorShapeDim(int64 s)` {#tensorflow_TensorShapeDim_TensorShapeDim}

View File

@ -1,25 +1,25 @@
# Struct `tensorflow::ThreadOptions` <a class="md-anchor" id="AUTOGENERATED-struct--tensorflow--threadoptions-"></a>
# Struct `tensorflow::ThreadOptions`
Options to configure a Thread .
Note that the options are all hints, and the underlying implementation may choose to ignore it.
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
##Member Summary
* [`size_t tensorflow::ThreadOptions::stack_size`](#size_t_tensorflow_ThreadOptions_stack_size)
* Thread stack size to use (in bytes).
* [`size_t tensorflow::ThreadOptions::guard_size`](#size_t_tensorflow_ThreadOptions_guard_size)
* Guard area size to use near thread stacks to use (in bytes)
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
##Member Details
#### `size_t tensorflow::ThreadOptions::stack_size` <a class="md-anchor" id="size_t_tensorflow_ThreadOptions_stack_size"></a>
#### `size_t tensorflow::ThreadOptions::stack_size` {#size_t_tensorflow_ThreadOptions_stack_size}
Thread stack size to use (in bytes).
#### `size_t tensorflow::ThreadOptions::guard_size` <a class="md-anchor" id="size_t_tensorflow_ThreadOptions_guard_size"></a>
#### `size_t tensorflow::ThreadOptions::guard_size` {#size_t_tensorflow_ThreadOptions_guard_size}
Guard area size to use near thread stacks to use (in bytes)

View File

@ -1,4 +1,4 @@
# TensorFlow C++ Session API reference documentation <a class="md-anchor" id="AUTOGENERATED-tensorflow-c---session-api-reference-documentation"></a>
# TensorFlow C++ Session API reference documentation
TensorFlow's public C++ API includes only the API for executing graphs, as of
version 0.5. To control the execution of a graph from C++:
@ -23,31 +23,31 @@ write the graph to a file.
1. Run the graph with a call to `session->Run()`
## Env <a class="md-anchor" id="AUTOGENERATED-env"></a>
## Env
* [tensorflow::Env](../../api_docs/cc/ClassEnv.md)
* [tensorflow::RandomAccessFile](../../api_docs/cc/ClassRandomAccessFile.md)
* [tensorflow::WritableFile](../../api_docs/cc/ClassWritableFile.md)
* [tensorflow::EnvWrapper](../../api_docs/cc/ClassEnvWrapper.md)
## Session <a class="md-anchor" id="AUTOGENERATED-session"></a>
## Session
* [tensorflow::Session](../../api_docs/cc/ClassSession.md)
* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)
## Status <a class="md-anchor" id="AUTOGENERATED-status"></a>
## Status
* [tensorflow::Status](../../api_docs/cc/ClassStatus.md)
* [tensorflow::Status::State](../../api_docs/cc/StructState.md)
## Tensor <a class="md-anchor" id="AUTOGENERATED-tensor"></a>
## Tensor
* [tensorflow::Tensor](../../api_docs/cc/ClassTensor.md)
* [tensorflow::TensorShape](../../api_docs/cc/ClassTensorShape.md)
* [tensorflow::TensorShapeDim](../../api_docs/cc/StructTensorShapeDim.md)
* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)
## Thread <a class="md-anchor" id="AUTOGENERATED-thread"></a>
## Thread
* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
* [tensorflow::ThreadOptions](../../api_docs/cc/StructThreadOptions.md)

View File

@ -1,4 +1,4 @@
# Overview <a class="md-anchor" id="AUTOGENERATED-overview"></a>
# Overview
TensorFlow has APIs available in several languages both for constructing and
executing a TensorFlow graph. The Python API is at present the most complete

View File

@ -1,61 +1,27 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Tensor Transformations <a class="md-anchor" id="AUTOGENERATED-tensor-transformations"></a>
# Tensor Transformations
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Tensor Transformations](#AUTOGENERATED-tensor-transformations)
* [Casting](#AUTOGENERATED-casting)
* [`tf.string_to_number(string_tensor, out_type=None, name=None)`](#string_to_number)
* [`tf.to_double(x, name='ToDouble')`](#to_double)
* [`tf.to_float(x, name='ToFloat')`](#to_float)
* [`tf.to_bfloat16(x, name='ToBFloat16')`](#to_bfloat16)
* [`tf.to_int32(x, name='ToInt32')`](#to_int32)
* [`tf.to_int64(x, name='ToInt64')`](#to_int64)
* [`tf.cast(x, dtype, name=None)`](#cast)
* [Shapes and Shaping](#AUTOGENERATED-shapes-and-shaping)
* [`tf.shape(input, name=None)`](#shape)
* [`tf.size(input, name=None)`](#size)
* [`tf.rank(input, name=None)`](#rank)
* [`tf.reshape(tensor, shape, name=None)`](#reshape)
* [`tf.squeeze(input, squeeze_dims=None, name=None)`](#squeeze)
* [`tf.expand_dims(input, dim, name=None)`](#expand_dims)
* [Slicing and Joining](#AUTOGENERATED-slicing-and-joining)
* [`tf.slice(input_, begin, size, name=None)`](#slice)
* [`tf.split(split_dim, num_split, value, name='split')`](#split)
* [`tf.tile(input, multiples, name=None)`](#tile)
* [`tf.pad(input, paddings, name=None)`](#pad)
* [`tf.concat(concat_dim, values, name='concat')`](#concat)
* [`tf.pack(values, name='pack')`](#pack)
* [`tf.unpack(value, num=None, name='unpack')`](#unpack)
* [`tf.reverse_sequence(input, seq_lengths, seq_dim, name=None)`](#reverse_sequence)
* [`tf.reverse(tensor, dims, name=None)`](#reverse)
* [`tf.transpose(a, perm=None, name='transpose')`](#transpose)
* [`tf.gather(params, indices, name=None)`](#gather)
* [`tf.dynamic_partition(data, partitions, num_partitions, name=None)`](#dynamic_partition)
* [`tf.dynamic_stitch(indices, data, name=None)`](#dynamic_stitch)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Casting <a class="md-anchor" id="AUTOGENERATED-casting"></a>
## Casting
TensorFlow provides several operations that you can use to cast tensor data
types in your graph.
- - -
### `tf.string_to_number(string_tensor, out_type=None, name=None)` <a class="md-anchor" id="string_to_number"></a>
### `tf.string_to_number(string_tensor, out_type=None, name=None)` {#string_to_number}
Converts each string in the input Tensor to the specified numeric type.
(Note that int32 overflow results in an error while float overflow
results in a rounded value.)
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`string_tensor`</b>: A `Tensor` of type `string`.
@ -63,7 +29,7 @@ results in a rounded value.)
The numeric type to interpret each string in string_tensor as.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `out_type`.
A Tensor of the same shape as the input string_tensor.
@ -71,21 +37,21 @@ results in a rounded value.)
- - -
### `tf.to_double(x, name='ToDouble')` <a class="md-anchor" id="to_double"></a>
### `tf.to_double(x, name='ToDouble')` {#to_double}
Casts a tensor to type `float64`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` or `SparseTensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `x` cannot be cast to the `float64`.
@ -93,21 +59,21 @@ Casts a tensor to type `float64`.
- - -
### `tf.to_float(x, name='ToFloat')` <a class="md-anchor" id="to_float"></a>
### `tf.to_float(x, name='ToFloat')` {#to_float}
Casts a tensor to type `float32`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` or `SparseTensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `x` cannot be cast to the `float32`.
@ -115,21 +81,21 @@ Casts a tensor to type `float32`.
- - -
### `tf.to_bfloat16(x, name='ToBFloat16')` <a class="md-anchor" id="to_bfloat16"></a>
### `tf.to_bfloat16(x, name='ToBFloat16')` {#to_bfloat16}
Casts a tensor to type `bfloat16`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` or `SparseTensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `x` cannot be cast to the `bfloat16`.
@ -137,21 +103,21 @@ Casts a tensor to type `bfloat16`.
- - -
### `tf.to_int32(x, name='ToInt32')` <a class="md-anchor" id="to_int32"></a>
### `tf.to_int32(x, name='ToInt32')` {#to_int32}
Casts a tensor to type `int32`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` or `SparseTensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `x` cannot be cast to the `int32`.
@ -159,21 +125,21 @@ Casts a tensor to type `int32`.
- - -
### `tf.to_int64(x, name='ToInt64')` <a class="md-anchor" id="to_int64"></a>
### `tf.to_int64(x, name='ToInt64')` {#to_int64}
Casts a tensor to type `int64`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` or `SparseTensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `x` cannot be cast to the `int64`.
@ -181,7 +147,7 @@ Casts a tensor to type `int64`.
- - -
### `tf.cast(x, dtype, name=None)` <a class="md-anchor" id="cast"></a>
### `tf.cast(x, dtype, name=None)` {#cast}
Casts a tensor to a new type.
@ -195,32 +161,32 @@ For example:
tf.cast(a, tf.int32) ==> [1, 2] # dtype=tf.int32
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` or `SparseTensor`.
* <b>`dtype`</b>: The destination type.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` or `SparseTensor` with same shape as `x`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `x` cannot be cast to the `dtype`.
## Shapes and Shaping <a class="md-anchor" id="AUTOGENERATED-shapes-and-shaping"></a>
## Shapes and Shaping
TensorFlow provides several operations that you can use to determine the shape
of a tensor and change the shape of a tensor.
- - -
### `tf.shape(input, name=None)` <a class="md-anchor" id="shape"></a>
### `tf.shape(input, name=None)` {#shape}
Returns the shape of a tensor.
@ -233,20 +199,20 @@ For example:
shape(t) ==> [2, 2, 3]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `int32`.
- - -
### `tf.size(input, name=None)` <a class="md-anchor" id="size"></a>
### `tf.size(input, name=None)` {#size}
Returns the size of a tensor.
@ -260,20 +226,20 @@ For example:
size(t) ==> 12
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `int32`.
- - -
### `tf.rank(input, name=None)` <a class="md-anchor" id="rank"></a>
### `tf.rank(input, name=None)` {#rank}
Returns the rank of a tensor.
@ -291,20 +257,20 @@ rank(t) ==> 3
of a tensor is the number of indices required to uniquely select each element
of the tensor. Rank is also known as "order", "degree", or "ndims."
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `int32`.
- - -
### `tf.reshape(tensor, shape, name=None)` <a class="md-anchor" id="reshape"></a>
### `tf.reshape(tensor, shape, name=None)` {#reshape}
Reshapes a tensor.
@ -344,21 +310,21 @@ reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`tensor`</b>: A `Tensor`.
* <b>`shape`</b>: A `Tensor` of type `int32`. Defines the shape of the output tensor.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `tensor`.
- - -
### `tf.squeeze(input, squeeze_dims=None, name=None)` <a class="md-anchor" id="squeeze"></a>
### `tf.squeeze(input, squeeze_dims=None, name=None)` {#squeeze}
Removes dimensions of size 1 from the shape of a tensor.
@ -381,7 +347,7 @@ Or, to remove specific size 1 dimensions:
shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`. The `input` to squeeze.
@ -390,7 +356,7 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
index starts at 0. It is an error to squeeze a dimension that is not 1.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
Contains the same data as `input`, but has one or more dimensions of
@ -399,7 +365,7 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]
- - -
### `tf.expand_dims(input, dim, name=None)` <a class="md-anchor" id="expand_dims"></a>
### `tf.expand_dims(input, dim, name=None)` {#expand_dims}
Inserts a dimension of 1 into a tensor's shape.
@ -434,7 +400,7 @@ This operation requires that:
This operation is related to `squeeze()`, which removes dimensions of
size 1.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`.
@ -443,7 +409,7 @@ size 1.
expand the shape of `input`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
Contains the same data as `input`, but its shape has an additional
@ -451,14 +417,14 @@ size 1.
## Slicing and Joining <a class="md-anchor" id="AUTOGENERATED-slicing-and-joining"></a>
## Slicing and Joining
TensorFlow provides several operations to slice or extract parts of a tensor,
or join multiple tensors together.
- - -
### `tf.slice(input_, begin, size, name=None)` <a class="md-anchor" id="slice"></a>
### `tf.slice(input_, begin, size, name=None)` {#slice}
Extracts a slice from a tensor.
@ -493,7 +459,7 @@ tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
[[5, 5, 5]]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input_`</b>: A `Tensor`.
@ -501,14 +467,14 @@ tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
* <b>`size`</b>: An `int32` or `int64` `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` the same type as `input`.
- - -
### `tf.split(split_dim, num_split, value, name='split')` <a class="md-anchor" id="split"></a>
### `tf.split(split_dim, num_split, value, name='split')` {#split}
Splits a tensor into `num_split` tensors along one dimension.
@ -524,7 +490,7 @@ split0, split1, split2 = tf.split(1, 3, value)
tf.shape(split0) ==> [5, 10]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`split_dim`</b>: A 0-D `int32` `Tensor`. The dimension along which to split.
@ -533,14 +499,14 @@ tf.shape(split0) ==> [5, 10]
* <b>`value`</b>: The `Tensor` to split.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
`num_split` `Tensor` objects resulting from splitting `value`.
- - -
### `tf.tile(input, multiples, name=None)` <a class="md-anchor" id="tile"></a>
### `tf.tile(input, multiples, name=None)` {#tile}
Constructs a tensor by tiling a given tensor.
@ -550,7 +516,7 @@ and the values of `input` are replicated `multiples[i]` times along the 'i'th
dimension. For example, tiling `[a b c d]` by `[2]` produces
`[a b c d a b c d]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`. 1-D or higher.
@ -558,14 +524,14 @@ dimension. For example, tiling `[a b c d]` by `[2]` produces
1-D. Length must be the same as the number of dimensions in `input`
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
- - -
### `tf.pad(input, paddings, name=None)` <a class="md-anchor" id="pad"></a>
### `tf.pad(input, paddings, name=None)` {#pad}
Pads a tensor with zeros.
@ -593,21 +559,21 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0]
[0, 0, 0, 0, 0]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`.
* <b>`paddings`</b>: A `Tensor` of type `int32`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
- - -
### `tf.concat(concat_dim, values, name='concat')` <a class="md-anchor" id="concat"></a>
### `tf.concat(concat_dim, values, name='concat')` {#concat}
Concatenates tensors along one dimension.
@ -641,21 +607,21 @@ tf.shape(tf.concat(0, [t3, t4])) ==> [4, 3]
tf.shape(tf.concat(1, [t3, t4])) ==> [2, 6]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`concat_dim`</b>: 0-D `int32` `Tensor`. Dimension along which to concatenate.
* <b>`values`</b>: A list of `Tensor` objects or a single `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` resulting from concatenation of the input tensors.
- - -
### `tf.pack(values, name='pack')` <a class="md-anchor" id="pack"></a>
### `tf.pack(values, name='pack')` {#pack}
Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
@ -667,13 +633,13 @@ This is the opposite of unpack. The numpy equivalent is
tf.pack([x, y, z]) = np.asarray([x, y, z])
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`values`</b>: A list of `Tensor` objects with the same shape and type.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`output`</b>: A packed `Tensor` with the same type as `values`.
@ -681,7 +647,7 @@ This is the opposite of unpack. The numpy equivalent is
- - -
### `tf.unpack(value, num=None, name='unpack')` <a class="md-anchor" id="unpack"></a>
### `tf.unpack(value, num=None, name='unpack')` {#unpack}
Unpacks the outer dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
@ -696,7 +662,7 @@ This is the opposite of pack. The numpy equivalent is
tf.unpack(x, n) = list(x)
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A rank `R > 0` `Tensor` to be unpacked.
@ -704,11 +670,11 @@ This is the opposite of pack. The numpy equivalent is
`None` (the default).
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The list of `Tensor` objects unpacked from `value`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If `num` is unspecified and cannot be inferred.
@ -716,7 +682,7 @@ This is the opposite of pack. The numpy equivalent is
- - -
### `tf.reverse_sequence(input, seq_lengths, seq_dim, name=None)` <a class="md-anchor" id="reverse_sequence"></a>
### `tf.reverse_sequence(input, seq_lengths, seq_dim, name=None)` {#reverse_sequence}
Reverses variable length slices in dimension `seq_dim`.
@ -750,7 +716,7 @@ output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 2:, :, ...] = input[3, 2:, :, ...]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`. The input to reverse.
@ -760,7 +726,7 @@ output[3, 2:, :, ...] = input[3, 2:, :, ...]
* <b>`seq_dim`</b>: An `int`. The dimension which is partially reversed.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
The partially reversed input. It has the same shape as `input`.
@ -768,7 +734,7 @@ output[3, 2:, :, ...] = input[3, 2:, :, ...]
- - -
### `tf.reverse(tensor, dims, name=None)` <a class="md-anchor" id="reverse"></a>
### `tf.reverse(tensor, dims, name=None)` {#reverse}
Reverses specific dimensions of a tensor.
@ -817,7 +783,7 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11],
[12, 13, 14, 15]]]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`tensor`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `bool`, `float32`, `float64`.
@ -825,14 +791,14 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11],
* <b>`dims`</b>: A `Tensor` of type `bool`. 1-D. The dimensions to reverse.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `tensor`. The same shape as `tensor`.
- - -
### `tf.transpose(a, perm=None, name='transpose')` <a class="md-anchor" id="transpose"></a>
### `tf.transpose(a, perm=None, name='transpose')` {#transpose}
Transposes `a`. Permutes the dimensions according to `perm`.
@ -870,21 +836,21 @@ tf.transpose(b, perm=[0, 2, 1]) ==> [[[1 4]
[9 12]]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`a`</b>: A `Tensor`.
* <b>`perm`</b>: A permutation of the dimensions of `a`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A transposed `Tensor`.
- - -
### `tf.gather(params, indices, name=None)` <a class="md-anchor" id="gather"></a>
### `tf.gather(params, indices, name=None)` {#gather}
Gather slices from `params` according to `indices`.
@ -907,21 +873,21 @@ this operation will permute `params` accordingly.
<img style="width:100%" src="../images/Gather.png" alt>
</div>
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`params`</b>: A `Tensor`.
* <b>`indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `params`.
- - -
### `tf.dynamic_partition(data, partitions, num_partitions, name=None)` <a class="md-anchor" id="dynamic_partition"></a>
### `tf.dynamic_partition(data, partitions, num_partitions, name=None)` {#dynamic_partition}
Partitions `data` into `num_partitions` tensors using indices from `partitions`.
@ -957,7 +923,7 @@ For example:
<img style="width:100%" src="../images/DynamicPartition.png" alt>
</div>
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`data`</b>: A `Tensor`.
@ -967,14 +933,14 @@ For example:
The number of partitions to output.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A list of `num_partitions` `Tensor` objects of the same type as data.
- - -
### `tf.dynamic_stitch(indices, data, name=None)` <a class="md-anchor" id="dynamic_stitch"></a>
### `tf.dynamic_stitch(indices, data, name=None)` {#dynamic_stitch}
Interleave the values from the `data` tensors into a single tensor.
@ -1016,14 +982,14 @@ For example:
<img style="width:100%" src="../images/DynamicStitch.png" alt>
</div>
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`indices`</b>: A list of at least 2 `Tensor` objects of type `int32`.
* <b>`data`</b>: A list with the same number of `Tensor` objects as `indices` of `Tensor` objects of the same type.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `data`.

View File

@ -1,45 +1,18 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Running Graphs <a class="md-anchor" id="AUTOGENERATED-running-graphs"></a>
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Running Graphs](#AUTOGENERATED-running-graphs)
* [Session management](#AUTOGENERATED-session-management)
* [`class tf.Session`](#Session)
* [`class tf.InteractiveSession`](#InteractiveSession)
* [`tf.get_default_session()`](#get_default_session)
* [Error classes](#AUTOGENERATED-error-classes)
* [`class tf.OpError`](#OpError)
* [`class tf.errors.CancelledError`](#CancelledError)
* [`class tf.errors.UnknownError`](#UnknownError)
* [`class tf.errors.InvalidArgumentError`](#InvalidArgumentError)
* [`class tf.errors.DeadlineExceededError`](#DeadlineExceededError)
* [`class tf.errors.NotFoundError`](#NotFoundError)
* [`class tf.errors.AlreadyExistsError`](#AlreadyExistsError)
* [`class tf.errors.PermissionDeniedError`](#PermissionDeniedError)
* [`class tf.errors.UnauthenticatedError`](#UnauthenticatedError)
* [`class tf.errors.ResourceExhaustedError`](#ResourceExhaustedError)
* [`class tf.errors.FailedPreconditionError`](#FailedPreconditionError)
* [`class tf.errors.AbortedError`](#AbortedError)
* [`class tf.errors.OutOfRangeError`](#OutOfRangeError)
* [`class tf.errors.UnimplementedError`](#UnimplementedError)
* [`class tf.errors.InternalError`](#InternalError)
* [`class tf.errors.UnavailableError`](#UnavailableError)
* [`class tf.errors.DataLossError`](#DataLossError)
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
# Running Graphs
[TOC]
This library contains classes for launching graphs and executing operations.
The [basic usage](../../get_started/index.md#basic-usage) guide has
examples of how a graph is launched in a [`tf.Session`](#Session).
## Session management <a class="md-anchor" id="AUTOGENERATED-session-management"></a>
## Session management
- - -
### `class tf.Session` <a class="md-anchor" id="Session"></a>
### `class tf.Session` {#Session}
A class for running TensorFlow operations.
@ -95,7 +68,7 @@ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
- - -
#### `tf.Session.__init__(target='', graph=None, config=None)` <a class="md-anchor" id="Session.__init__"></a>
#### `tf.Session.__init__(target='', graph=None, config=None)` {#Session.__init__}
Creates a new TensorFlow session.
@ -107,7 +80,7 @@ but each graph can be used in multiple sessions. In this case, it
is often clearer to pass the graph to be launched explicitly to
the session constructor.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`target`</b>: (Optional.) The execution engine to connect to.
@ -120,7 +93,7 @@ the session constructor.
- - -
#### `tf.Session.run(fetches, feed_dict=None)` <a class="md-anchor" id="Session.run"></a>
#### `tf.Session.run(fetches, feed_dict=None)` {#Session.run}
Runs the operations and evaluates the tensors in `fetches`.
@ -160,7 +133,7 @@ one of the following types:
the value should be a
[`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue).
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`fetches`</b>: A single graph element, or a list of graph elements
@ -168,12 +141,12 @@ one of the following types:
* <b>`feed_dict`</b>: A dictionary that maps graph elements to values
(described above).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Either a single value if `fetches` is a single graph element, or
a list of values if `fetches` is a list (described above).
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`RuntimeError`</b>: If this `Session` is in an invalid state (e.g. has been
@ -185,13 +158,13 @@ one of the following types:
- - -
#### `tf.Session.close()` <a class="md-anchor" id="Session.close"></a>
#### `tf.Session.close()` {#Session.close}
Closes this session.
Calling this method frees all resources associated with the session.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`RuntimeError`</b>: If an error occurs while closing the session.
@ -200,14 +173,14 @@ Calling this method frees all resources associated with the session.
- - -
#### `tf.Session.graph` <a class="md-anchor" id="Session.graph"></a>
#### `tf.Session.graph` {#Session.graph}
The graph that was launched in this session.
- - -
#### `tf.Session.as_default()` <a class="md-anchor" id="Session.as_default"></a>
#### `tf.Session.as_default()` {#Session.as_default}
Returns a context manager that makes this object the default session.
@ -254,7 +227,7 @@ create a new thread, and wish to use the default session in that
thread, you must explicitly add a `with sess.as_default():` in that
thread's function.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A context manager using this session as the default session.
@ -262,7 +235,7 @@ thread's function.
- - -
### `class tf.InteractiveSession` <a class="md-anchor" id="InteractiveSession"></a>
### `class tf.InteractiveSession` {#InteractiveSession}
A TensorFlow `Session` for use in interactive contexts, such as a shell.
@ -303,7 +276,7 @@ with tf.Session():
- - -
#### `tf.InteractiveSession.__init__(target='', graph=None)` <a class="md-anchor" id="InteractiveSession.__init__"></a>
#### `tf.InteractiveSession.__init__(target='', graph=None)` {#InteractiveSession.__init__}
Creates a new interactive TensorFlow session.
@ -315,7 +288,7 @@ but each graph can be used in multiple sessions. In this case, it
is often clearer to pass the graph to be launched explicitly to
the session constructor.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`target`</b>: (Optional.) The execution engine to connect to.
@ -326,7 +299,7 @@ the session constructor.
- - -
#### `tf.InteractiveSession.close()` <a class="md-anchor" id="InteractiveSession.close"></a>
#### `tf.InteractiveSession.close()` {#InteractiveSession.close}
Closes an `InteractiveSession`.
@ -335,7 +308,7 @@ Closes an `InteractiveSession`.
- - -
### `tf.get_default_session()` <a class="md-anchor" id="get_default_session"></a>
### `tf.get_default_session()` {#get_default_session}
Returns the default session for the current thread.
@ -347,17 +320,17 @@ create a new thread, and wish to use the default session in that
thread, you must explicitly add a `with sess.as_default():` in that
thread's function.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The default `Session` being used in the current thread.
## Error classes <a class="md-anchor" id="AUTOGENERATED-error-classes"></a>
## Error classes
- - -
### `class tf.OpError` <a class="md-anchor" id="OpError"></a>
### `class tf.OpError` {#OpError}
A generic error that is raised when TensorFlow execution fails.
@ -366,7 +339,7 @@ of `OpError` from the `tf.errors` module.
- - -
#### `tf.OpError.op` <a class="md-anchor" id="OpError.op"></a>
#### `tf.OpError.op` {#OpError.op}
The operation that failed, if known.
@ -377,25 +350,25 @@ will return `None`, and you should instead use the
[`OpError.node_def`](#OpError.node_def) to discover information about the
op.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The `Operation` that failed, or None.
- - -
#### `tf.OpError.node_def` <a class="md-anchor" id="OpError.node_def"></a>
#### `tf.OpError.node_def` {#OpError.node_def}
The `NodeDef` proto representing the op that failed.
#### Other Methods <a class="md-anchor" id="AUTOGENERATED-other-methods"></a>
#### Other Methods
- - -
#### `tf.OpError.__init__(node_def, op, message, error_code)` <a class="md-anchor" id="OpError.__init__"></a>
#### `tf.OpError.__init__(node_def, op, message, error_code)` {#OpError.__init__}
Creates a new OpError indicating that a particular op failed.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`node_def`</b>: The graph_pb2.NodeDef proto representing the op that failed.
@ -406,20 +379,20 @@ Creates a new OpError indicating that a particular op failed.
- - -
#### `tf.OpError.error_code` <a class="md-anchor" id="OpError.error_code"></a>
#### `tf.OpError.error_code` {#OpError.error_code}
The integer error code that describes the error.
- - -
#### `tf.OpError.message` <a class="md-anchor" id="OpError.message"></a>
#### `tf.OpError.message` {#OpError.message}
The error message that describes the error.
- - -
### `class tf.errors.CancelledError` <a class="md-anchor" id="CancelledError"></a>
### `class tf.errors.CancelledError` {#CancelledError}
Raised when an operation or step is cancelled.
@ -433,7 +406,7 @@ A step that is running such a long-running operation will fail by raising
- - -
#### `tf.errors.CancelledError.__init__(node_def, op, message)` <a class="md-anchor" id="CancelledError.__init__"></a>
#### `tf.errors.CancelledError.__init__(node_def, op, message)` {#CancelledError.__init__}
Creates a `CancelledError`.
@ -441,7 +414,7 @@ Creates a `CancelledError`.
- - -
### `class tf.errors.UnknownError` <a class="md-anchor" id="UnknownError"></a>
### `class tf.errors.UnknownError` {#UnknownError}
Unknown error.
@ -453,7 +426,7 @@ error.
- - -
#### `tf.errors.UnknownError.__init__(node_def, op, message, error_code=2)` <a class="md-anchor" id="UnknownError.__init__"></a>
#### `tf.errors.UnknownError.__init__(node_def, op, message, error_code=2)` {#UnknownError.__init__}
Creates an `UnknownError`.
@ -461,7 +434,7 @@ Creates an `UnknownError`.
- - -
### `class tf.errors.InvalidArgumentError` <a class="md-anchor" id="InvalidArgumentError"></a>
### `class tf.errors.InvalidArgumentError` {#InvalidArgumentError}
Raised when an operation receives an invalid argument.
@ -475,7 +448,7 @@ tensor.
- - -
#### `tf.errors.InvalidArgumentError.__init__(node_def, op, message)` <a class="md-anchor" id="InvalidArgumentError.__init__"></a>
#### `tf.errors.InvalidArgumentError.__init__(node_def, op, message)` {#InvalidArgumentError.__init__}
Creates an `InvalidArgumentError`.
@ -483,7 +456,7 @@ Creates an `InvalidArgumentError`.
- - -
### `class tf.errors.DeadlineExceededError` <a class="md-anchor" id="DeadlineExceededError"></a>
### `class tf.errors.DeadlineExceededError` {#DeadlineExceededError}
Raised when a deadline expires before an operation could complete.
@ -491,7 +464,7 @@ This exception is not currently used.
- - -
#### `tf.errors.DeadlineExceededError.__init__(node_def, op, message)` <a class="md-anchor" id="DeadlineExceededError.__init__"></a>
#### `tf.errors.DeadlineExceededError.__init__(node_def, op, message)` {#DeadlineExceededError.__init__}
Creates a `DeadlineExceededError`.
@ -499,7 +472,7 @@ Creates a `DeadlineExceededError`.
- - -
### `class tf.errors.NotFoundError` <a class="md-anchor" id="NotFoundError"></a>
### `class tf.errors.NotFoundError` {#NotFoundError}
Raised when a requested entity (e.g., a file or directory) was not found.
@ -510,7 +483,7 @@ does not exist.
- - -
#### `tf.errors.NotFoundError.__init__(node_def, op, message)` <a class="md-anchor" id="NotFoundError.__init__"></a>
#### `tf.errors.NotFoundError.__init__(node_def, op, message)` {#NotFoundError.__init__}
Creates a `NotFoundError`.
@ -518,7 +491,7 @@ Creates a `NotFoundError`.
- - -
### `class tf.errors.AlreadyExistsError` <a class="md-anchor" id="AlreadyExistsError"></a>
### `class tf.errors.AlreadyExistsError` {#AlreadyExistsError}
Raised when an entity that we attempted to create already exists.
@ -529,7 +502,7 @@ existing file was passed.
- - -
#### `tf.errors.AlreadyExistsError.__init__(node_def, op, message)` <a class="md-anchor" id="AlreadyExistsError.__init__"></a>
#### `tf.errors.AlreadyExistsError.__init__(node_def, op, message)` {#AlreadyExistsError.__init__}
Creates an `AlreadyExistsError`.
@ -537,7 +510,7 @@ Creates an `AlreadyExistsError`.
- - -
### `class tf.errors.PermissionDeniedError` <a class="md-anchor" id="PermissionDeniedError"></a>
### `class tf.errors.PermissionDeniedError` {#PermissionDeniedError}
Raised when the caller does not have permission to run an operation.
@ -548,7 +521,7 @@ file for which the user does not have the read file permission.
- - -
#### `tf.errors.PermissionDeniedError.__init__(node_def, op, message)` <a class="md-anchor" id="PermissionDeniedError.__init__"></a>
#### `tf.errors.PermissionDeniedError.__init__(node_def, op, message)` {#PermissionDeniedError.__init__}
Creates a `PermissionDeniedError`.
@ -556,7 +529,7 @@ Creates a `PermissionDeniedError`.
- - -
### `class tf.errors.UnauthenticatedError` <a class="md-anchor" id="UnauthenticatedError"></a>
### `class tf.errors.UnauthenticatedError` {#UnauthenticatedError}
The request does not have valid authentication credentials.
@ -564,7 +537,7 @@ This exception is not currently used.
- - -
#### `tf.errors.UnauthenticatedError.__init__(node_def, op, message)` <a class="md-anchor" id="UnauthenticatedError.__init__"></a>
#### `tf.errors.UnauthenticatedError.__init__(node_def, op, message)` {#UnauthenticatedError.__init__}
Creates an `UnauthenticatedError`.
@ -572,7 +545,7 @@ Creates an `UnauthenticatedError`.
- - -
### `class tf.errors.ResourceExhaustedError` <a class="md-anchor" id="ResourceExhaustedError"></a>
### `class tf.errors.ResourceExhaustedError` {#ResourceExhaustedError}
Some resource has been exhausted.
@ -581,7 +554,7 @@ exhausted, or perhaps the entire file system is out of space.
- - -
#### `tf.errors.ResourceExhaustedError.__init__(node_def, op, message)` <a class="md-anchor" id="ResourceExhaustedError.__init__"></a>
#### `tf.errors.ResourceExhaustedError.__init__(node_def, op, message)` {#ResourceExhaustedError.__init__}
Creates a `ResourceExhaustedError`.
@ -589,7 +562,7 @@ Creates a `ResourceExhaustedError`.
- - -
### `class tf.errors.FailedPreconditionError` <a class="md-anchor" id="FailedPreconditionError"></a>
### `class tf.errors.FailedPreconditionError` {#FailedPreconditionError}
Operation was rejected because the system is not in a state to execute it.
@ -599,7 +572,7 @@ before it has been initialized.
- - -
#### `tf.errors.FailedPreconditionError.__init__(node_def, op, message)` <a class="md-anchor" id="FailedPreconditionError.__init__"></a>
#### `tf.errors.FailedPreconditionError.__init__(node_def, op, message)` {#FailedPreconditionError.__init__}
Creates a `FailedPreconditionError`.
@ -607,7 +580,7 @@ Creates a `FailedPreconditionError`.
- - -
### `class tf.errors.AbortedError` <a class="md-anchor" id="AbortedError"></a>
### `class tf.errors.AbortedError` {#AbortedError}
The operation was aborted, typically due to a concurrent action.
@ -619,7 +592,7 @@ previously ran.
- - -
#### `tf.errors.AbortedError.__init__(node_def, op, message)` <a class="md-anchor" id="AbortedError.__init__"></a>
#### `tf.errors.AbortedError.__init__(node_def, op, message)` {#AbortedError.__init__}
Creates an `AbortedError`.
@ -627,7 +600,7 @@ Creates an `AbortedError`.
- - -
### `class tf.errors.OutOfRangeError` <a class="md-anchor" id="OutOfRangeError"></a>
### `class tf.errors.OutOfRangeError` {#OutOfRangeError}
Raised when an operation executed past the valid range.
@ -639,7 +612,7 @@ operation executes.
- - -
#### `tf.errors.OutOfRangeError.__init__(node_def, op, message)` <a class="md-anchor" id="OutOfRangeError.__init__"></a>
#### `tf.errors.OutOfRangeError.__init__(node_def, op, message)` {#OutOfRangeError.__init__}
Creates an `OutOfRangeError`.
@ -647,7 +620,7 @@ Creates an `OutOfRangeError`.
- - -
### `class tf.errors.UnimplementedError` <a class="md-anchor" id="UnimplementedError"></a>
### `class tf.errors.UnimplementedError` {#UnimplementedError}
Raised when an operation has not been implemented.
@ -659,7 +632,7 @@ because this is not yet supported.
- - -
#### `tf.errors.UnimplementedError.__init__(node_def, op, message)` <a class="md-anchor" id="UnimplementedError.__init__"></a>
#### `tf.errors.UnimplementedError.__init__(node_def, op, message)` {#UnimplementedError.__init__}
Creates an `UnimplementedError`.
@ -667,7 +640,7 @@ Creates an `UnimplementedError`.
- - -
### `class tf.errors.InternalError` <a class="md-anchor" id="InternalError"></a>
### `class tf.errors.InternalError` {#InternalError}
Raised when the system experiences an internal error.
@ -676,7 +649,7 @@ has been broken. Catching this exception is not recommended.
- - -
#### `tf.errors.InternalError.__init__(node_def, op, message)` <a class="md-anchor" id="InternalError.__init__"></a>
#### `tf.errors.InternalError.__init__(node_def, op, message)` {#InternalError.__init__}
Creates an `InternalError`.
@ -684,7 +657,7 @@ Creates an `InternalError`.
- - -
### `class tf.errors.UnavailableError` <a class="md-anchor" id="UnavailableError"></a>
### `class tf.errors.UnavailableError` {#UnavailableError}
Raised when the runtime is currently unavailable.
@ -692,7 +665,7 @@ This exception is not currently used.
- - -
#### `tf.errors.UnavailableError.__init__(node_def, op, message)` <a class="md-anchor" id="UnavailableError.__init__"></a>
#### `tf.errors.UnavailableError.__init__(node_def, op, message)` {#UnavailableError.__init__}
Creates an `UnavailableError`.
@ -700,7 +673,7 @@ Creates an `UnavailableError`.
- - -
### `class tf.errors.DataLossError` <a class="md-anchor" id="DataLossError"></a>
### `class tf.errors.DataLossError` {#DataLossError}
Raised when unrecoverable data loss or corruption is encountered.
@ -710,7 +683,7 @@ operation, if the file is truncated while it is being read.
- - -
#### `tf.errors.DataLossError.__init__(node_def, op, message)` <a class="md-anchor" id="DataLossError.__init__"></a>
#### `tf.errors.DataLossError.__init__(node_def, op, message)` {#DataLossError.__init__}
Creates a `DataLossError`.

View File

@ -1,41 +1,19 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Constants, Sequences, and Random Values <a class="md-anchor" id="AUTOGENERATED-constants--sequences--and-random-values"></a>
# Constants, Sequences, and Random Values
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Constants, Sequences, and Random Values](#AUTOGENERATED-constants--sequences--and-random-values)
* [Constant Value Tensors](#AUTOGENERATED-constant-value-tensors)
* [`tf.zeros(shape, dtype=tf.float32, name=None)`](#zeros)
* [`tf.zeros_like(tensor, dtype=None, name=None)`](#zeros_like)
* [`tf.ones(shape, dtype=tf.float32, name=None)`](#ones)
* [`tf.ones_like(tensor, dtype=None, name=None)`](#ones_like)
* [`tf.fill(dims, value, name=None)`](#fill)
* [`tf.constant(value, dtype=None, shape=None, name='Const')`](#constant)
* [Sequences](#AUTOGENERATED-sequences)
* [`tf.linspace(start, stop, num, name=None)`](#linspace)
* [`tf.range(start, limit=None, delta=1, name='range')`](#range)
* [Random Tensors](#AUTOGENERATED-random-tensors)
* [Examples:](#AUTOGENERATED-examples-)
* [`tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)`](#random_normal)
* [`tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)`](#truncated_normal)
* [`tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)`](#random_uniform)
* [`tf.random_shuffle(value, seed=None, name=None)`](#random_shuffle)
* [`tf.set_random_seed(seed)`](#set_random_seed)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Constant Value Tensors <a class="md-anchor" id="AUTOGENERATED-constant-value-tensors"></a>
## Constant Value Tensors
TensorFlow provides several operations that you can use to generate constants.
- - -
### `tf.zeros(shape, dtype=tf.float32, name=None)` <a class="md-anchor" id="zeros"></a>
### `tf.zeros(shape, dtype=tf.float32, name=None)` {#zeros}
Creates a tensor with all elements set to zero.
@ -48,21 +26,21 @@ For example:
tf.zeros([3, 4], int32) ==> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`shape`</b>: Either a list of integers, or a 1-D `Tensor` of type `int32`.
* <b>`dtype`</b>: The type of an element in the resulting `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with all elements set to zero.
- - -
### `tf.zeros_like(tensor, dtype=None, name=None)` <a class="md-anchor" id="zeros_like"></a>
### `tf.zeros_like(tensor, dtype=None, name=None)` {#zeros_like}
Creates a tensor with all elements set to zero.
@ -77,7 +55,7 @@ For example:
tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`tensor`</b>: A `Tensor`.
@ -86,7 +64,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with all elements set to zero.
@ -94,7 +72,7 @@ tf.zeros_like(tensor) ==> [[0, 0, 0], [0, 0, 0]]
- - -
### `tf.ones(shape, dtype=tf.float32, name=None)` <a class="md-anchor" id="ones"></a>
### `tf.ones(shape, dtype=tf.float32, name=None)` {#ones}
Creates a tensor with all elements set to 1.
@ -107,21 +85,21 @@ For example:
tf.ones([2, 3], int32) ==> [[1, 1, 1], [1, 1, 1]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`shape`</b>: Either a list of integers, or a 1-D `Tensor` of type `int32`.
* <b>`dtype`</b>: The type of an element in the resulting `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with all elements set to 1.
- - -
### `tf.ones_like(tensor, dtype=None, name=None)` <a class="md-anchor" id="ones_like"></a>
### `tf.ones_like(tensor, dtype=None, name=None)` {#ones_like}
Creates a tensor with all elements set to 1.
@ -136,7 +114,7 @@ For example:
tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`tensor`</b>: A `Tensor`.
@ -145,7 +123,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with all elements set to 1.
@ -153,7 +131,7 @@ tf.ones_like(tensor) ==> [[1, 1, 1], [1, 1, 1]]
- - -
### `tf.fill(dims, value, name=None)` <a class="md-anchor" id="fill"></a>
### `tf.fill(dims, value, name=None)` {#fill}
Creates a tensor filled with a scalar value.
@ -168,7 +146,7 @@ fill(dims, 9) ==> [[9, 9, 9]
[9, 9, 9]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`dims`</b>: A `Tensor` of type `int32`.
@ -176,7 +154,7 @@ fill(dims, 9) ==> [[9, 9, 9]
* <b>`value`</b>: A `Tensor`. 0-D (scalar). Value to fill the returned tensor.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `value`.
@ -184,7 +162,7 @@ fill(dims, 9) ==> [[9, 9, 9]
- - -
### `tf.constant(value, dtype=None, shape=None, name='Const')` <a class="md-anchor" id="constant"></a>
### `tf.constant(value, dtype=None, shape=None, name='Const')` {#constant}
Creates a constant tensor.
@ -217,7 +195,7 @@ Creates a constant tensor.
[-1. -1. -1.]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A constant value (or list) of output type `dtype`.
@ -231,17 +209,17 @@ Creates a constant tensor.
* <b>`name`</b>: Optional name for the tensor.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A Constant Tensor.
## Sequences <a class="md-anchor" id="AUTOGENERATED-sequences"></a>
## Sequences
- - -
### `tf.linspace(start, stop, num, name=None)` <a class="md-anchor" id="linspace"></a>
### `tf.linspace(start, stop, num, name=None)` {#linspace}
Generates values in an interval.
@ -255,7 +233,7 @@ For example:
tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`start`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
@ -265,7 +243,7 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
* <b>`num`</b>: A `Tensor` of type `int32`. Number of values to generate.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `start`. 1-D. The generated values.
@ -273,7 +251,7 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
- - -
### `tf.range(start, limit=None, delta=1, name='range')` <a class="md-anchor" id="range"></a>
### `tf.range(start, limit=None, delta=1, name='range')` {#range}
Creates a sequence of integers.
@ -295,7 +273,7 @@ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
tf.range(limit) ==> [0, 1, 2, 3, 4]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`start`</b>: A 0-D (scalar) of type `int32`. First entry in sequence.
@ -306,13 +284,13 @@ tf.range(limit) ==> [0, 1, 2, 3, 4]
Number that increments `start`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An 1-D `int32` `Tensor`.
## Random Tensors <a class="md-anchor" id="AUTOGENERATED-random-tensors"></a>
## Random Tensors
TensorFlow has several ops that create random tensors with different
distributions. The random ops are stateful, and create new random values each
@ -328,7 +306,7 @@ See [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for details on the interaction between operation-level and graph-level random
seeds.
### Examples: <a class="md-anchor" id="AUTOGENERATED-examples-"></a>
### Examples:
```python
# Create a tensor of shape [2, 3] consisting of random normal values, with mean
@ -368,11 +346,11 @@ print sess.run(var)
- - -
### `tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)` <a class="md-anchor" id="random_normal"></a>
### `tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)` {#random_normal}
Outputs random values from a normal distribution.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`shape`</b>: A 1-D integer Tensor or Python array. The shape of the output tensor.
@ -387,14 +365,14 @@ Outputs random values from a normal distribution.
for behavior.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tensor of the specified shape filled with random normal values.
- - -
### `tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)` <a class="md-anchor" id="truncated_normal"></a>
### `tf.truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)` {#truncated_normal}
Outputs random values from a truncated normal distribution.
@ -402,7 +380,7 @@ The generated values follow a normal distribution with specified mean and
standard deviation, except that values whose magnitude is more than 2 standard
deviations from the mean are dropped and re-picked.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`shape`</b>: A 1-D integer Tensor or Python array. The shape of the output tensor.
@ -417,14 +395,14 @@ deviations from the mean are dropped and re-picked.
for behavior.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tensor of the specified shape filled with random truncated normal values.
- - -
### `tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)` <a class="md-anchor" id="random_uniform"></a>
### `tf.random_uniform(shape, minval=0.0, maxval=1.0, dtype=tf.float32, seed=None, name=None)` {#random_uniform}
Outputs random values from a uniform distribution.
@ -432,7 +410,7 @@ The generated values follow a uniform distribution in the range
`[minval, maxval)`. The lower bound `minval` is included in the range, while
the upper bound `maxval` is excluded.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`shape`</b>: A 1-D integer Tensor or Python array. The shape of the output tensor.
@ -447,14 +425,14 @@ the upper bound `maxval` is excluded.
for behavior.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tensor of the specified shape filled with random uniform values.
- - -
### `tf.random_shuffle(value, seed=None, name=None)` <a class="md-anchor" id="random_shuffle"></a>
### `tf.random_shuffle(value, seed=None, name=None)` {#random_shuffle}
Randomly shuffles a tensor along its first dimension.
@ -468,7 +446,7 @@ to one and only one `output[i]`. For example, a mapping that might occur for a
[5, 6]] [3, 4]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A Tensor to be shuffled.
@ -478,7 +456,7 @@ to one and only one `output[i]`. For example, a mapping that might occur for a
for behavior.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tensor of same shape and type as `value`, shuffled along its first
dimension.
@ -486,7 +464,7 @@ to one and only one `output[i]`. For example, a mapping that might occur for a
- - -
### `tf.set_random_seed(seed)` <a class="md-anchor" id="set_random_seed"></a>
### `tf.set_random_seed(seed)` {#set_random_seed}
Sets the graph-level random seed.
@ -579,7 +557,7 @@ with tf.Session() as sess2:
print sess2.run(b) # generates 'B2'
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`seed`</b>: integer.

View File

@ -1,73 +1,37 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Control Flow <a class="md-anchor" id="AUTOGENERATED-control-flow"></a>
# Control Flow
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Control Flow](#AUTOGENERATED-control-flow)
* [Control Flow Operations](#AUTOGENERATED-control-flow-operations)
* [`tf.identity(input, name=None)`](#identity)
* [`tf.tuple(tensors, name=None, control_inputs=None)`](#tuple)
* [`tf.group(*inputs, **kwargs)`](#group)
* [`tf.no_op(name=None)`](#no_op)
* [`tf.count_up_to(ref, limit, name=None)`](#count_up_to)
* [Logical Operators](#AUTOGENERATED-logical-operators)
* [`tf.logical_and(x, y, name=None)`](#logical_and)
* [`tf.logical_not(x, name=None)`](#logical_not)
* [`tf.logical_or(x, y, name=None)`](#logical_or)
* [`tf.logical_xor(x, y, name='LogicalXor')`](#logical_xor)
* [Comparison Operators](#AUTOGENERATED-comparison-operators)
* [`tf.equal(x, y, name=None)`](#equal)
* [`tf.not_equal(x, y, name=None)`](#not_equal)
* [`tf.less(x, y, name=None)`](#less)
* [`tf.less_equal(x, y, name=None)`](#less_equal)
* [`tf.greater(x, y, name=None)`](#greater)
* [`tf.greater_equal(x, y, name=None)`](#greater_equal)
* [`tf.select(condition, t, e, name=None)`](#select)
* [`tf.where(input, name=None)`](#where)
* [Debugging Operations](#AUTOGENERATED-debugging-operations)
* [`tf.is_finite(x, name=None)`](#is_finite)
* [`tf.is_inf(x, name=None)`](#is_inf)
* [`tf.is_nan(x, name=None)`](#is_nan)
* [`tf.verify_tensor_all_finite(t, msg, name=None)`](#verify_tensor_all_finite)
* [`tf.check_numerics(tensor, message, name=None)`](#check_numerics)
* [`tf.add_check_numerics_ops()`](#add_check_numerics_ops)
* [`tf.Assert(condition, data, summarize=None, name=None)`](#Assert)
* [`tf.Print(input_, data, message=None, first_n=None, summarize=None, name=None)`](#Print)
* [Other Functions and Classes](#AUTOGENERATED-other-functions-and-classes)
* [`class tf.xrange`](#xrange)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Control Flow Operations <a class="md-anchor" id="AUTOGENERATED-control-flow-operations"></a>
## Control Flow Operations
TensorFlow provides several operations and classes that you can use to control
the execution of operations and add conditional dependencies to your graph.
- - -
### `tf.identity(input, name=None)` <a class="md-anchor" id="identity"></a>
### `tf.identity(input, name=None)` {#identity}
Return a tensor with the same shape and contents as the input tensor or value.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
- - -
### `tf.tuple(tensors, name=None, control_inputs=None)` <a class="md-anchor" id="tuple"></a>
### `tf.tuple(tensors, name=None, control_inputs=None)` {#tuple}
Group tensors together.
@ -85,18 +49,18 @@ are done.
See also `group` and `with_dependencies`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`tensors`</b>: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
* <b>`name`</b>: (optional) A name to use as a `name_scope` for the operation.
* <b>`control_inputs`</b>: List of additional ops to finish before returning.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Same as `tensors`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
@ -104,7 +68,7 @@ See also `group` and `with_dependencies`.
- - -
### `tf.group(*inputs, **kwargs)` <a class="md-anchor" id="group"></a>
### `tf.group(*inputs, **kwargs)` {#group}
Create an op that groups multiple operations.
@ -113,18 +77,18 @@ output.
See also `tuple` and `with_dependencies`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`*inputs`</b>: One or more tensors to group.
* <b>`**kwargs`</b>: Optional parameters to pass when constructing the NodeDef.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An Operation that executes all its inputs.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If an unknown keyword argument is provided, or if there are
@ -133,30 +97,30 @@ See also `tuple` and `with_dependencies`.
- - -
### `tf.no_op(name=None)` <a class="md-anchor" id="no_op"></a>
### `tf.no_op(name=None)` {#no_op}
Does nothing. Only useful as a placeholder for control edges.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The created Operation.
- - -
### `tf.count_up_to(ref, limit, name=None)` <a class="md-anchor" id="count_up_to"></a>
### `tf.count_up_to(ref, limit, name=None)` {#count_up_to}
Increments 'ref' until it reaches 'limit'.
This operation outputs "ref" after the update is done. This makes it
easier to chain operations that need to use the updated value.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`ref`</b>: A mutable `Tensor`. Must be one of the following types: `int32`, `int64`.
@ -166,7 +130,7 @@ easier to chain operations that need to use the updated value.
'OutOfRange' error.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `ref`.
A copy of the input before increment. If nothing else modifies the
@ -174,188 +138,188 @@ easier to chain operations that need to use the updated value.
## Logical Operators <a class="md-anchor" id="AUTOGENERATED-logical-operators"></a>
## Logical Operators
TensorFlow provides several operations that you can use to add logical operators
to your graph.
- - -
### `tf.logical_and(x, y, name=None)` <a class="md-anchor" id="logical_and"></a>
### `tf.logical_and(x, y, name=None)` {#logical_and}
Returns the truth value of x AND y element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` of type `bool`.
* <b>`y`</b>: A `Tensor` of type `bool`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.logical_not(x, name=None)` <a class="md-anchor" id="logical_not"></a>
### `tf.logical_not(x, name=None)` {#logical_not}
Returns the truth value of NOT x element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` of type `bool`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.logical_or(x, y, name=None)` <a class="md-anchor" id="logical_or"></a>
### `tf.logical_or(x, y, name=None)` {#logical_or}
Returns the truth value of x OR y element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor` of type `bool`.
* <b>`y`</b>: A `Tensor` of type `bool`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.logical_xor(x, y, name='LogicalXor')` <a class="md-anchor" id="logical_xor"></a>
### `tf.logical_xor(x, y, name='LogicalXor')` {#logical_xor}
x ^ y = (x | y) & ~(x & y).
## Comparison Operators <a class="md-anchor" id="AUTOGENERATED-comparison-operators"></a>
## Comparison Operators
TensorFlow provides several operations that you can use to add comparison
operators to your graph.
- - -
### `tf.equal(x, y, name=None)` <a class="md-anchor" id="equal"></a>
### `tf.equal(x, y, name=None)` {#equal}
Returns the truth value of (x == y) element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`, `quint8`, `qint8`, `qint32`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.not_equal(x, y, name=None)` <a class="md-anchor" id="not_equal"></a>
### `tf.not_equal(x, y, name=None)` {#not_equal}
Returns the truth value of (x != y) element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `complex64`, `quint8`, `qint8`, `qint32`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.less(x, y, name=None)` <a class="md-anchor" id="less"></a>
### `tf.less(x, y, name=None)` {#less}
Returns the truth value of (x < y) element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.less_equal(x, y, name=None)` <a class="md-anchor" id="less_equal"></a>
### `tf.less_equal(x, y, name=None)` {#less_equal}
Returns the truth value of (x <= y) element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.greater(x, y, name=None)` <a class="md-anchor" id="greater"></a>
### `tf.greater(x, y, name=None)` {#greater}
Returns the truth value of (x > y) element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.greater_equal(x, y, name=None)` <a class="md-anchor" id="greater_equal"></a>
### `tf.greater_equal(x, y, name=None)` {#greater_equal}
Returns the truth value of (x >= y) element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`.
* <b>`y`</b>: A `Tensor`. Must have the same type as `x`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.select(condition, t, e, name=None)` <a class="md-anchor" id="select"></a>
### `tf.select(condition, t, e, name=None)` {#select}
Selects elements from `t` or `e`, depending on `condition`.
@ -378,7 +342,7 @@ select(condition, t, e) ==> [[1, 2],
[1, 2]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`condition`</b>: A `Tensor` of type `bool`.
@ -386,14 +350,14 @@ select(condition, t, e) ==> [[1, 2],
* <b>`e`</b>: A `Tensor` with the same type and shape as `t`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same type and shape as `t` and `e`.
- - -
### `tf.where(input, name=None)` <a class="md-anchor" id="where"></a>
### `tf.where(input, name=None)` {#where}
Returns locations of true values in a boolean tensor.
@ -429,116 +393,116 @@ where(input) ==> [[0, 0, 0],
[2, 1, 1]]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor` of type `bool`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `int64`.
## Debugging Operations <a class="md-anchor" id="AUTOGENERATED-debugging-operations"></a>
## Debugging Operations
TensorFlow provides several operations that you can use to validate values and
debug your graph.
- - -
### `tf.is_finite(x, name=None)` <a class="md-anchor" id="is_finite"></a>
### `tf.is_finite(x, name=None)` {#is_finite}
Returns which elements of x are finite.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.is_inf(x, name=None)` <a class="md-anchor" id="is_inf"></a>
### `tf.is_inf(x, name=None)` {#is_inf}
Returns which elements of x are Inf.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.is_nan(x, name=None)` <a class="md-anchor" id="is_nan"></a>
### `tf.is_nan(x, name=None)` {#is_nan}
Returns which elements of x are NaN.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`.
- - -
### `tf.verify_tensor_all_finite(t, msg, name=None)` <a class="md-anchor" id="verify_tensor_all_finite"></a>
### `tf.verify_tensor_all_finite(t, msg, name=None)` {#verify_tensor_all_finite}
Assert that the tensor does not contain any NaN's or Inf's.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`t`</b>: Tensor to check.
* <b>`msg`</b>: Message to log on failure.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Same tensor as `t`.
- - -
### `tf.check_numerics(tensor, message, name=None)` <a class="md-anchor" id="check_numerics"></a>
### `tf.check_numerics(tensor, message, name=None)` {#check_numerics}
Checks a tensor for NaN and Inf values.
When run, reports an `InvalidArgument` error if `tensor` has any values
that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`tensor`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
* <b>`message`</b>: A `string`. Prefix of the error message.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `tensor`.
- - -
### `tf.add_check_numerics_ops()` <a class="md-anchor" id="add_check_numerics_ops"></a>
### `tf.add_check_numerics_ops()` {#add_check_numerics_ops}
Connect a check_numerics to every floating point tensor.
@ -547,21 +511,21 @@ tensor in the graph. For all ops in the graph, the `check_numerics` op for
all of its (`float` or `double`) inputs is guaranteed to run before the
`check_numerics` op on any of its outputs.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `group` op depending on all `check_numerics` ops added.
- - -
### `tf.Assert(condition, data, summarize=None, name=None)` <a class="md-anchor" id="Assert"></a>
### `tf.Assert(condition, data, summarize=None, name=None)` {#Assert}
Asserts that the given condition is true.
If `condition` evaluates to false, print the list of tensors in `data`.
`summarize` determines how many entries of the tensors to print.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`condition`</b>: The condition to evaluate.
@ -572,14 +536,14 @@ If `condition` evaluates to false, print the list of tensors in `data`.
- - -
### `tf.Print(input_, data, message=None, first_n=None, summarize=None, name=None)` <a class="md-anchor" id="Print"></a>
### `tf.Print(input_, data, message=None, first_n=None, summarize=None, name=None)` {#Print}
Prints a list of tensors.
This is an identity op with the side effect of printing `data` when
evaluating.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input_`</b>: A tensor passed through this op.
@ -590,16 +554,16 @@ evaluating.
* <b>`summarize`</b>: Only print this many entries of each tensor.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Same tensor as `input_`.
## Other Functions and Classes <a class="md-anchor" id="AUTOGENERATED-other-functions-and-classes"></a>
## Other Functions and Classes
- - -
### `class tf.xrange` <a class="md-anchor" id="xrange"></a>
### `class tf.xrange` {#xrange}
xrange(stop) -> xrange object
xrange(start, stop[, step]) -> xrange object

File diff suppressed because it is too large Load Diff

View File

@ -1,47 +1,13 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Images <a class="md-anchor" id="AUTOGENERATED-images"></a>
# Images
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Images](#AUTOGENERATED-images)
* [Encoding and Decoding](#AUTOGENERATED-encoding-and-decoding)
* [`tf.image.decode_jpeg(contents, channels=None, ratio=None, fancy_upscaling=None, try_recover_truncated=None, acceptable_fraction=None, name=None)`](#decode_jpeg)
* [`tf.image.encode_jpeg(image, format=None, quality=None, progressive=None, optimize_size=None, chroma_downsampling=None, density_unit=None, x_density=None, y_density=None, xmp_metadata=None, name=None)`](#encode_jpeg)
* [`tf.image.decode_png(contents, channels=None, name=None)`](#decode_png)
* [`tf.image.encode_png(image, compression=None, name=None)`](#encode_png)
* [Resizing](#AUTOGENERATED-resizing)
* [`tf.image.resize_images(images, new_height, new_width, method=0)`](#resize_images)
* [`tf.image.resize_area(images, size, name=None)`](#resize_area)
* [`tf.image.resize_bicubic(images, size, name=None)`](#resize_bicubic)
* [`tf.image.resize_bilinear(images, size, name=None)`](#resize_bilinear)
* [`tf.image.resize_nearest_neighbor(images, size, name=None)`](#resize_nearest_neighbor)
* [Cropping](#AUTOGENERATED-cropping)
* [`tf.image.resize_image_with_crop_or_pad(image, target_height, target_width)`](#resize_image_with_crop_or_pad)
* [`tf.image.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width)`](#pad_to_bounding_box)
* [`tf.image.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width)`](#crop_to_bounding_box)
* [`tf.image.random_crop(image, size, seed=None, name=None)`](#random_crop)
* [`tf.image.extract_glimpse(input, size, offsets, centered=None, normalized=None, uniform_noise=None, name=None)`](#extract_glimpse)
* [Flipping and Transposing](#AUTOGENERATED-flipping-and-transposing)
* [`tf.image.flip_up_down(image)`](#flip_up_down)
* [`tf.image.random_flip_up_down(image, seed=None)`](#random_flip_up_down)
* [`tf.image.flip_left_right(image)`](#flip_left_right)
* [`tf.image.random_flip_left_right(image, seed=None)`](#random_flip_left_right)
* [`tf.image.transpose_image(image)`](#transpose_image)
* [Image Adjustments](#AUTOGENERATED-image-adjustments)
* [`tf.image.adjust_brightness(image, delta, min_value=None, max_value=None)`](#adjust_brightness)
* [`tf.image.random_brightness(image, max_delta, seed=None)`](#random_brightness)
* [`tf.image.adjust_contrast(images, contrast_factor, min_value=None, max_value=None)`](#adjust_contrast)
* [`tf.image.random_contrast(image, lower, upper, seed=None)`](#random_contrast)
* [`tf.image.per_image_whitening(image)`](#per_image_whitening)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Encoding and Decoding <a class="md-anchor" id="AUTOGENERATED-encoding-and-decoding"></a>
## Encoding and Decoding
TensorFlow provides Ops to decode and encode JPEG and PNG formats. Encoded
images are represented by scalar string Tensors, decoded images by 3-D uint8
@ -56,7 +22,7 @@ presently only support RGB, HSV, and GrayScale.
- - -
### `tf.image.decode_jpeg(contents, channels=None, ratio=None, fancy_upscaling=None, try_recover_truncated=None, acceptable_fraction=None, name=None)` <a class="md-anchor" id="decode_jpeg"></a>
### `tf.image.decode_jpeg(contents, channels=None, ratio=None, fancy_upscaling=None, try_recover_truncated=None, acceptable_fraction=None, name=None)` {#decode_jpeg}
Decode a JPEG-encoded image to a uint8 tensor.
@ -76,7 +42,7 @@ The attr `ratio` allows downscaling the image by an integer factor during
decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
downscaling the image later.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`contents`</b>: A `Tensor` of type `string`. 0-D. The JPEG-encoded image.
@ -93,14 +59,14 @@ downscaling the image later.
input is accepted.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `uint8`. 3-D with shape `[height, width, channels]`..
- - -
### `tf.image.encode_jpeg(image, format=None, quality=None, progressive=None, optimize_size=None, chroma_downsampling=None, density_unit=None, x_density=None, y_density=None, xmp_metadata=None, name=None)` <a class="md-anchor" id="encode_jpeg"></a>
### `tf.image.encode_jpeg(image, format=None, quality=None, progressive=None, optimize_size=None, chroma_downsampling=None, density_unit=None, x_density=None, y_density=None, xmp_metadata=None, name=None)` {#encode_jpeg}
JPEG-encode an image.
@ -121,7 +87,7 @@ in function of the number of channels in `image`:
* 1: Output a grayscale image.
* 3: Output an RGB image.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A `Tensor` of type `uint8`.
@ -147,7 +113,7 @@ in function of the number of channels in `image`:
If not empty, embed this XMP metadata in the image header.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `string`. 0-D. JPEG-encoded image.
@ -155,7 +121,7 @@ in function of the number of channels in `image`:
- - -
### `tf.image.decode_png(contents, channels=None, name=None)` <a class="md-anchor" id="decode_png"></a>
### `tf.image.decode_png(contents, channels=None, name=None)` {#decode_png}
Decode a PNG-encoded image to a uint8 tensor.
@ -172,7 +138,7 @@ Accepted values are:
If needed, the PNG-encoded image is transformed to match the requested number
of color channels.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`contents`</b>: A `Tensor` of type `string`. 0-D. The PNG-encoded image.
@ -180,14 +146,14 @@ of color channels.
Number of color channels for the decoded image.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `uint8`. 3-D with shape `[height, width, channels]`.
- - -
### `tf.image.encode_png(image, compression=None, name=None)` <a class="md-anchor" id="encode_png"></a>
### `tf.image.encode_png(image, compression=None, name=None)` {#encode_png}
PNG-encode an image.
@ -202,7 +168,7 @@ The ZLIB compression level, `compression`, can be -1 for the PNG-encoder
default or a value from 0 to 9. 9 is the highest compression level, generating
the smallest output, but is slower.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A `Tensor` of type `uint8`.
@ -210,13 +176,13 @@ the smallest output, but is slower.
* <b>`compression`</b>: An optional `int`. Defaults to `-1`. Compression level.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `string`. 0-D. PNG-encoded image.
## Resizing <a class="md-anchor" id="AUTOGENERATED-resizing"></a>
## Resizing
The resizing Ops accept input images as tensors of several types. They always
output resized images as float32 tensors.
@ -244,7 +210,7 @@ images from the Queue.</i>
- - -
### `tf.image.resize_images(images, new_height, new_width, method=0)` <a class="md-anchor" id="resize_images"></a>
### `tf.image.resize_images(images, new_height, new_width, method=0)` {#resize_images}
Resize `images` to `new_width`, `new_height` using the specified `method`.
@ -262,7 +228,7 @@ the same as `new_width`, `new_height`. To avoid distortions see
(https://en.wikipedia.org/wiki/Bicubic_interpolation)
* <b>ResizeMethod.AREA</b>: Area interpolation.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`images`</b>: 4-D Tensor of shape `[batch, height, width, channels]` or
@ -271,14 +237,14 @@ the same as `new_width`, `new_height`. To avoid distortions see
* <b>`new_width`</b>: integer.
* <b>`method`</b>: ResizeMethod. Defaults to `ResizeMethod.BILINEAR`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of `images` is incompatible with the
shape arguments to this function
* <b>`ValueError`</b>: if an unsupported resize method is specified.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
If `images` was 4-D, a 4-D float Tensor of shape
`[batch, new_height, new_width, channels]`.
@ -289,13 +255,13 @@ the same as `new_width`, `new_height`. To avoid distortions see
- - -
### `tf.image.resize_area(images, size, name=None)` <a class="md-anchor" id="resize_area"></a>
### `tf.image.resize_area(images, size, name=None)` {#resize_area}
Resize `images` to `size` using area interpolation.
Input images can be of different types but output images are always float.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`.
@ -304,7 +270,7 @@ Input images can be of different types but output images are always float.
new size for the images.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `float32`. 4-D with shape
`[batch, new_height, new_width, channels]`.
@ -312,13 +278,13 @@ Input images can be of different types but output images are always float.
- - -
### `tf.image.resize_bicubic(images, size, name=None)` <a class="md-anchor" id="resize_bicubic"></a>
### `tf.image.resize_bicubic(images, size, name=None)` {#resize_bicubic}
Resize `images` to `size` using bicubic interpolation.
Input images can be of different types but output images are always float.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`.
@ -327,7 +293,7 @@ Input images can be of different types but output images are always float.
new size for the images.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `float32`. 4-D with shape
`[batch, new_height, new_width, channels]`.
@ -335,13 +301,13 @@ Input images can be of different types but output images are always float.
- - -
### `tf.image.resize_bilinear(images, size, name=None)` <a class="md-anchor" id="resize_bilinear"></a>
### `tf.image.resize_bilinear(images, size, name=None)` {#resize_bilinear}
Resize `images` to `size` using bilinear interpolation.
Input images can be of different types but output images are always float.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`.
@ -350,7 +316,7 @@ Input images can be of different types but output images are always float.
new size for the images.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `float32`. 4-D with shape
`[batch, new_height, new_width, channels]`.
@ -358,13 +324,13 @@ Input images can be of different types but output images are always float.
- - -
### `tf.image.resize_nearest_neighbor(images, size, name=None)` <a class="md-anchor" id="resize_nearest_neighbor"></a>
### `tf.image.resize_nearest_neighbor(images, size, name=None)` {#resize_nearest_neighbor}
Resize `images` to `size` using nearest neighbor interpolation.
Input images can be of different types but output images are always float.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`.
@ -373,7 +339,7 @@ Input images can be of different types but output images are always float.
new size for the images.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `images`. 4-D with shape
`[batch, new_height, new_width, channels]`.
@ -381,11 +347,11 @@ Input images can be of different types but output images are always float.
## Cropping <a class="md-anchor" id="AUTOGENERATED-cropping"></a>
## Cropping
- - -
### `tf.image.resize_image_with_crop_or_pad(image, target_height, target_width)` <a class="md-anchor" id="resize_image_with_crop_or_pad"></a>
### `tf.image.resize_image_with_crop_or_pad(image, target_height, target_width)` {#resize_image_with_crop_or_pad}
Crops and/or pads an image to a target width and height.
@ -398,19 +364,19 @@ If `width` or `height` is smaller than the specified `target_width` or
`target_height` respectively, this op centrally pads with 0 along that
dimension.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor of shape [height, width, channels]
* <b>`target_height`</b>: Target height.
* <b>`target_width`</b>: Target width.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if `target_height` or `target_width` are zero or negative.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Cropped and/or padded image of shape
`[target_height, target_width, channels]`
@ -419,7 +385,7 @@ dimension.
- - -
### `tf.image.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width)` <a class="md-anchor" id="pad_to_bounding_box"></a>
### `tf.image.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width)` {#pad_to_bounding_box}
Pad `image` with zeros to the specified `height` and `width`.
@ -430,7 +396,7 @@ with zeros until it has dimensions `target_height`, `target_width`.
This op does nothing if `offset_*` is zero and the image already has size
`target_height` by `target_width`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor with shape `[height, width, channels]`
@ -439,11 +405,11 @@ This op does nothing if `offset_*` is zero and the image already has size
* <b>`target_height`</b>: Height of output image.
* <b>`target_width`</b>: Width of output image.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
3-D tensor of shape `[target_height, target_width, channels]`
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If the shape of `image` is incompatible with the `offset_*` or
@ -452,7 +418,7 @@ This op does nothing if `offset_*` is zero and the image already has size
- - -
### `tf.image.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width)` <a class="md-anchor" id="crop_to_bounding_box"></a>
### `tf.image.crop_to_bounding_box(image, offset_height, offset_width, target_height, target_width)` {#crop_to_bounding_box}
Crops an image to a specified bounding box.
@ -461,7 +427,7 @@ returned image is at `offset_height, offset_width` in `image`, and its
lower-right corner is at
`offset_height + target_height, offset_width + target_width'.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor with shape `[height, width, channels]`
@ -472,11 +438,11 @@ lower-right corner is at
* <b>`target_height`</b>: Height of the result.
* <b>`target_width`</b>: Width of the result.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
3-D tensor of image with shape `[target_height, target_width, channels]`
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If the shape of `image` is incompatible with the `offset_*` or
@ -485,14 +451,14 @@ lower-right corner is at
- - -
### `tf.image.random_crop(image, size, seed=None, name=None)` <a class="md-anchor" id="random_crop"></a>
### `tf.image.random_crop(image, size, seed=None, name=None)` {#random_crop}
Randomly crops `image` to size `[target_height, target_width]`.
The offset of the output within `image` is uniformly random. `image` always
fully contains the result.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor of shape `[height, width, channels]`
@ -502,14 +468,14 @@ fully contains the result.
for behavior.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A cropped 3-D tensor of shape `[target_height, target_width, channels]`.
- - -
### `tf.image.extract_glimpse(input, size, offsets, centered=None, normalized=None, uniform_noise=None, name=None)` <a class="md-anchor" id="extract_glimpse"></a>
### `tf.image.extract_glimpse(input, size, offsets, centered=None, normalized=None, uniform_noise=None, name=None)` {#extract_glimpse}
Extracts a glimpse from the input tensor.
@ -530,7 +496,7 @@ The argument `normalized` and `centered` controls how the windows are built:
lower right corner is located at (1.0, 1.0) and the center is at (0, 0).
* If the coordinates are not normalized they are interpreted as numbers of pixels.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor` of type `float32`.
@ -553,7 +519,7 @@ The argument `normalized` and `centered` controls how the windows are built:
uniform distribution or a gaussian distribution.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `float32`.
A tensor representing the glimpses `[batch_size, glimpse_height,
@ -561,11 +527,11 @@ The argument `normalized` and `centered` controls how the windows are built:
## Flipping and Transposing <a class="md-anchor" id="AUTOGENERATED-flipping-and-transposing"></a>
## Flipping and Transposing
- - -
### `tf.image.flip_up_down(image)` <a class="md-anchor" id="flip_up_down"></a>
### `tf.image.flip_up_down(image)` {#flip_up_down}
Flip an image horizontally (upside down).
@ -574,16 +540,16 @@ Outputs the contents of `image` flipped along the first dimension, which is
See also `reverse()`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A 3-D tensor of shape `[height, width, channels].`
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 3-D tensor of the same type and shape as `image`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of `image` not supported.
@ -591,14 +557,14 @@ See also `reverse()`.
- - -
### `tf.image.random_flip_up_down(image, seed=None)` <a class="md-anchor" id="random_flip_up_down"></a>
### `tf.image.random_flip_up_down(image, seed=None)` {#random_flip_up_down}
Randomly flips an image vertically (upside down).
With a 1 in 2 chance, outputs the contents of `image` flipped along the first
dimension, which is `height`. Otherwise output the image as-is.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A 3-D tensor of shape `[height, width, channels].`
@ -606,11 +572,11 @@ dimension, which is `height`. Otherwise output the image as-is.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 3-D tensor of the same type and shape as `image`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of `image` not supported.
@ -619,7 +585,7 @@ dimension, which is `height`. Otherwise output the image as-is.
- - -
### `tf.image.flip_left_right(image)` <a class="md-anchor" id="flip_left_right"></a>
### `tf.image.flip_left_right(image)` {#flip_left_right}
Flip an image horizontally (left to right).
@ -628,16 +594,16 @@ Outputs the contents of `image` flipped along the second dimension, which is
See also `reverse()`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A 3-D tensor of shape `[height, width, channels].`
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 3-D tensor of the same type and shape as `image`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of `image` not supported.
@ -645,14 +611,14 @@ See also `reverse()`.
- - -
### `tf.image.random_flip_left_right(image, seed=None)` <a class="md-anchor" id="random_flip_left_right"></a>
### `tf.image.random_flip_left_right(image, seed=None)` {#random_flip_left_right}
Randomly flip an image horizontally (left to right).
With a 1 in 2 chance, outputs the contents of `image` flipped along the
second dimension, which is `width`. Otherwise output the image as-is.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A 3-D tensor of shape `[height, width, channels].`
@ -660,11 +626,11 @@ second dimension, which is `width`. Otherwise output the image as-is.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 3-D tensor of the same type and shape as `image`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of `image` not supported.
@ -673,29 +639,29 @@ second dimension, which is `width`. Otherwise output the image as-is.
- - -
### `tf.image.transpose_image(image)` <a class="md-anchor" id="transpose_image"></a>
### `tf.image.transpose_image(image)` {#transpose_image}
Transpose an image by swapping the first and second dimension.
See also `transpose()`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor of shape `[height, width, channels]`
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 3-D tensor of shape `[width, height, channels]`
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of `image` not supported.
## Image Adjustments <a class="md-anchor" id="AUTOGENERATED-image-adjustments"></a>
## Image Adjustments
TensorFlow provides functions to adjust images in various ways: brightness,
contrast, hue, and saturation. Each adjustment can be done with predefined
@ -704,7 +670,7 @@ adjustments are often useful to expand a training set and reduce overfitting.
- - -
### `tf.image.adjust_brightness(image, delta, min_value=None, max_value=None)` <a class="md-anchor" id="adjust_brightness"></a>
### `tf.image.adjust_brightness(image, delta, min_value=None, max_value=None)` {#adjust_brightness}
Adjust the brightness of RGB or Grayscale images.
@ -716,7 +682,7 @@ clamped to `[min_value, max_value]`. Finally, the result is cast back to
If `min_value` or `max_value` are not given, they are set to the minimum and
maximum allowed values for `image.dtype` respectively.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: A tensor.
@ -724,14 +690,14 @@ maximum allowed values for `image.dtype` respectively.
* <b>`min_value`</b>: Minimum value for output.
* <b>`max_value`</b>: Maximum value for output.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tensor of the same shape and type as `image`.
- - -
### `tf.image.random_brightness(image, max_delta, seed=None)` <a class="md-anchor" id="random_brightness"></a>
### `tf.image.random_brightness(image, max_delta, seed=None)` {#random_brightness}
Adjust the brightness of images by a random factor.
@ -742,7 +708,7 @@ Note that `delta` is picked as a float. Because for integer type images,
the brightness adjusted result is rounded before casting, integer images may
have modifications in the range `[-max_delta,max_delta]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor of shape `[height, width, channels]`.
@ -751,11 +717,11 @@ have modifications in the range `[-max_delta,max_delta]`.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
3-D tensor of images of shape `[height, width, channels]`
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if max_delta is negative.
@ -764,7 +730,7 @@ have modifications in the range `[-max_delta,max_delta]`.
- - -
### `tf.image.adjust_contrast(images, contrast_factor, min_value=None, max_value=None)` <a class="md-anchor" id="adjust_contrast"></a>
### `tf.image.adjust_contrast(images, contrast_factor, min_value=None, max_value=None)` {#adjust_contrast}
Adjust contrast of RGB or grayscale images.
@ -785,7 +751,7 @@ minimum and maximum values for the data type of `images` respectively.
The contrast-adjusted image is always computed as `float`, and it is
cast back to its original type after clipping.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`images`</b>: Images to adjust. At least 3-D.
@ -793,11 +759,11 @@ cast back to its original type after clipping.
* <b>`min_value`</b>: Minimum value for clipping the adjusted pixels.
* <b>`max_value`</b>: Maximum value for clipping the adjusted pixels.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The constrast-adjusted image or images.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the arguments are invalid.
@ -805,14 +771,14 @@ cast back to its original type after clipping.
- - -
### `tf.image.random_contrast(image, lower, upper, seed=None)` <a class="md-anchor" id="random_contrast"></a>
### `tf.image.random_contrast(image, lower, upper, seed=None)` {#random_contrast}
Adjust the contrase of an image by a random factor.
Equivalent to `adjust_constrast()` but uses a `contrast_factor` randomly
picked in the interval `[lower, upper]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor of shape `[height, width, channels]`.
@ -822,11 +788,11 @@ picked in the interval `[lower, upper]`.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
3-D tensor of shape `[height, width, channels]`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if `upper <= lower` or if `lower < 0`.
@ -835,7 +801,7 @@ picked in the interval `[lower, upper]`.
- - -
### `tf.image.per_image_whitening(image)` <a class="md-anchor" id="per_image_whitening"></a>
### `tf.image.per_image_whitening(image)` {#per_image_whitening}
Linearly scales `image` to have zero mean and unit norm.
@ -850,16 +816,16 @@ Note that this implementation is limited:
* It only whitens based on the statistics of an individual image.
* It does not take into account the covariance structure.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`image`</b>: 3-D tensor of shape `[height, width, channels]`.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The whitened image with same shape as `image`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if the shape of 'image' is incompatible with this function.

View File

@ -1,6 +1,6 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# TensorFlow Python reference documentation <a class="md-anchor" id="AUTOGENERATED-tensorflow-python-reference-documentation"></a>
# TensorFlow Python reference documentation
* **[Building Graphs](../../api_docs/python/framework.md)**:
* [`add_to_collection`](../../api_docs/python/framework.md#add_to_collection)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,60 +1,13 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Neural Network <a class="md-anchor" id="AUTOGENERATED-neural-network"></a>
# Neural Network
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Neural Network](#AUTOGENERATED-neural-network)
* [Activation Functions](#AUTOGENERATED-activation-functions)
* [`tf.nn.relu(features, name=None)`](#relu)
* [`tf.nn.relu6(features, name=None)`](#relu6)
* [`tf.nn.softplus(features, name=None)`](#softplus)
* [`tf.nn.dropout(x, keep_prob, noise_shape=None, seed=None, name=None)`](#dropout)
* [`tf.nn.bias_add(value, bias, name=None)`](#bias_add)
* [`tf.sigmoid(x, name=None)`](#sigmoid)
* [`tf.tanh(x, name=None)`](#tanh)
* [Convolution](#AUTOGENERATED-convolution)
* [`tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)`](#conv2d)
* [`tf.nn.depthwise_conv2d(input, filter, strides, padding, name=None)`](#depthwise_conv2d)
* [`tf.nn.separable_conv2d(input, depthwise_filter, pointwise_filter, strides, padding, name=None)`](#separable_conv2d)
* [Pooling](#AUTOGENERATED-pooling)
* [`tf.nn.avg_pool(value, ksize, strides, padding, name=None)`](#avg_pool)
* [`tf.nn.max_pool(value, ksize, strides, padding, name=None)`](#max_pool)
* [`tf.nn.max_pool_with_argmax(input, ksize, strides, padding, Targmax=None, name=None)`](#max_pool_with_argmax)
* [Normalization](#AUTOGENERATED-normalization)
* [`tf.nn.l2_normalize(x, dim, epsilon=1e-12, name=None)`](#l2_normalize)
* [`tf.nn.local_response_normalization(input, depth_radius=None, bias=None, alpha=None, beta=None, name=None)`](#local_response_normalization)
* [`tf.nn.moments(x, axes, name=None)`](#moments)
* [Losses](#AUTOGENERATED-losses)
* [`tf.nn.l2_loss(t, name=None)`](#l2_loss)
* [Classification](#AUTOGENERATED-classification)
* [`tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None)`](#sigmoid_cross_entropy_with_logits)
* [`tf.nn.softmax(logits, name=None)`](#softmax)
* [`tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)`](#softmax_cross_entropy_with_logits)
* [Embeddings](#AUTOGENERATED-embeddings)
* [`tf.nn.embedding_lookup(params, ids, name=None)`](#embedding_lookup)
* [Evaluation](#AUTOGENERATED-evaluation)
* [`tf.nn.top_k(input, k, name=None)`](#top_k)
* [`tf.nn.in_top_k(predictions, targets, k, name=None)`](#in_top_k)
* [Candidate Sampling](#AUTOGENERATED-candidate-sampling)
* [Sampled Loss Functions](#AUTOGENERATED-sampled-loss-functions)
* [`tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=False, name='nce_loss')`](#nce_loss)
* [`tf.nn.sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=True, name='sampled_softmax_loss')`](#sampled_softmax_loss)
* [Candidate Samplers](#AUTOGENERATED-candidate-samplers)
* [`tf.nn.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)`](#uniform_candidate_sampler)
* [`tf.nn.log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)`](#log_uniform_candidate_sampler)
* [`tf.nn.learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)`](#learned_unigram_candidate_sampler)
* [`tf.nn.fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, vocab_file='', distortion=0.0, num_reserved_ids=0, num_shards=1, shard=0, unigrams=[], seed=None, name=None)`](#fixed_unigram_candidate_sampler)
* [Miscellaneous candidate sampling utilities](#AUTOGENERATED-miscellaneous-candidate-sampling-utilities)
* [`tf.nn.compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None)`](#compute_accidental_hits)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Activation Functions <a class="md-anchor" id="AUTOGENERATED-activation-functions"></a>
## Activation Functions
The activation ops provide different types of nonlinearities for use in
neural networks. These include smooth nonlinearities (`sigmoid`,
@ -67,59 +20,59 @@ shape as the input tensor.
- - -
### `tf.nn.relu(features, name=None)` <a class="md-anchor" id="relu"></a>
### `tf.nn.relu(features, name=None)` {#relu}
Computes rectified linear: `max(features, 0)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`features`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `features`.
- - -
### `tf.nn.relu6(features, name=None)` <a class="md-anchor" id="relu6"></a>
### `tf.nn.relu6(features, name=None)` {#relu6}
Computes Rectified Linear 6: `min(max(features, 0), 6)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`features`</b>: A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`,
`int16`, or `int8`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same type as `features`.
- - -
### `tf.nn.softplus(features, name=None)` <a class="md-anchor" id="softplus"></a>
### `tf.nn.softplus(features, name=None)` {#softplus}
Computes softplus: `log(exp(features) + 1)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`features`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `features`.
- - -
### `tf.nn.dropout(x, keep_prob, noise_shape=None, seed=None, name=None)` <a class="md-anchor" id="dropout"></a>
### `tf.nn.dropout(x, keep_prob, noise_shape=None, seed=None, name=None)` {#dropout}
Computes dropout.
@ -135,7 +88,7 @@ will make independent decisions. For example, if `shape(x) = [k, l, m, n]`
and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
kept independently and each row and column will be kept or not kept together.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A tensor.
@ -148,11 +101,11 @@ kept independently and each row and column will be kept or not kept together.
for behavior.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A Tensor of the same shape of `x`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If `keep_prob` is not in `(0, 1]`.
@ -160,7 +113,7 @@ kept independently and each row and column will be kept or not kept together.
- - -
### `tf.nn.bias_add(value, bias, name=None)` <a class="md-anchor" id="bias_add"></a>
### `tf.nn.bias_add(value, bias, name=None)` {#bias_add}
Adds `bias` to `value`.
@ -169,7 +122,7 @@ Broadcasting is supported, so `value` may have any number of dimensions.
Unlike `tf.add`, the type of `bias` is allowed to differ from `value` in the
case where both types are quantized.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A `Tensor` with type `float`, `double`, `int64`, `int32`, `uint8`,
@ -179,27 +132,27 @@ case where both types are quantized.
in which case a different quantized type may be used.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same type as `value`.
- - -
### `tf.sigmoid(x, name=None)` <a class="md-anchor" id="sigmoid"></a>
### `tf.sigmoid(x, name=None)` {#sigmoid}
Computes sigmoid of `x` element-wise.
Specifically, `y = 1 / (1 + exp(-x))`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A Tensor with the same type as `x` if `x.dtype != qint32`
otherwise the return type is `quint8`.
@ -207,25 +160,25 @@ Specifically, `y = 1 / (1 + exp(-x))`.
- - -
### `tf.tanh(x, name=None)` <a class="md-anchor" id="tanh"></a>
### `tf.tanh(x, name=None)` {#tanh}
Computes hyperbolic tangent of `x` element-wise.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A Tensor with type `float`, `double`, `int32`, `complex64`, `int64`,
or `qint32`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A Tensor with the same type as `x` if `x.dtype != qint32` otherwise
the return type is `quint8`.
## Convolution <a class="md-anchor" id="AUTOGENERATED-convolution"></a>
## Convolution
The convolution ops sweep a 2-D filter over a batch of images, applying the
filter to each window of each image of the appropriate size. The different
@ -295,7 +248,7 @@ concatenated.
- - -
### `tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)` <a class="md-anchor" id="conv2d"></a>
### `tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)` {#conv2d}
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@ -321,7 +274,7 @@ In detail,
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
@ -334,14 +287,14 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
* <b>`use_cudnn_on_gpu`</b>: An optional `bool`. Defaults to `True`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `input`.
- - -
### `tf.nn.depthwise_conv2d(input, filter, strides, padding, name=None)` <a class="md-anchor" id="depthwise_conv2d"></a>
### `tf.nn.depthwise_conv2d(input, filter, strides, padding, name=None)` {#depthwise_conv2d}
Depthwise 2-D convolution.
@ -362,7 +315,7 @@ In detail,
Must have `strides[0] = strides[3] = 1`. For the most common case of the
same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: 4-D with shape `[batch, in_height, in_width, in_channels]`.
@ -373,7 +326,7 @@ same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 4-D `Tensor` of shape
`[batch, out_height, out_width, in_channels * channel_multiplier].`
@ -381,7 +334,7 @@ same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
- - -
### `tf.nn.separable_conv2d(input, depthwise_filter, pointwise_filter, strides, padding, name=None)` <a class="md-anchor" id="separable_conv2d"></a>
### `tf.nn.separable_conv2d(input, depthwise_filter, pointwise_filter, strides, padding, name=None)` {#separable_conv2d}
2-D convolution with separable filters.
@ -402,7 +355,7 @@ the pointwise convolution has implicit strides of `[1, 1, 1, 1]`. Must have
`strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: 4-D `Tensor` with shape `[batch, in_height, in_width, in_channels]`.
@ -417,13 +370,13 @@ horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`.
## Pooling <a class="md-anchor" id="AUTOGENERATED-pooling"></a>
## Pooling
The pooling ops sweep a rectangular window over the input tensor, computing a
reduction operation for each window (average, max, or max with argmax). Each
@ -440,14 +393,14 @@ to the `Convolution` section for details about the padding calculation.
- - -
### `tf.nn.avg_pool(value, ksize, strides, padding, name=None)` <a class="md-anchor" id="avg_pool"></a>
### `tf.nn.avg_pool(value, ksize, strides, padding, name=None)` {#avg_pool}
Performs the average pooling on the input.
Each entry in `output` is the mean of the corresponding size `ksize`
window in `value`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A 4-D `Tensor` of shape `[batch, height, width, channels]` and type
@ -460,18 +413,18 @@ window in `value`.
* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
* <b>`name`</b>: Optional name for the operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same type as `value`. The average pooled output tensor.
- - -
### `tf.nn.max_pool(value, ksize, strides, padding, name=None)` <a class="md-anchor" id="max_pool"></a>
### `tf.nn.max_pool(value, ksize, strides, padding, name=None)` {#max_pool}
Performs the max pooling on the input.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A 4-D `Tensor` with shape `[batch, height, width, channels]` and
@ -483,14 +436,14 @@ Performs the max pooling on the input.
* <b>`padding`</b>: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
* <b>`name`</b>: Optional name for the operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same type as `value`. The max pooled output tensor.
- - -
### `tf.nn.max_pool_with_argmax(input, ksize, strides, padding, Targmax=None, name=None)` <a class="md-anchor" id="max_pool_with_argmax"></a>
### `tf.nn.max_pool_with_argmax(input, ksize, strides, padding, Targmax=None, name=None)` {#max_pool_with_argmax}
Performs max pooling on the input and outputs both max values and indices.
@ -498,7 +451,7 @@ The indices in `argmax` are flattened, so that a maximum value at position
`[b, y, x, c]` becomes flattened index
`((b * height + y) * width + x) * channels + c`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor` of type `float32`.
@ -513,7 +466,7 @@ The indices in `argmax` are flattened, so that a maximum value at position
* <b>`Targmax`</b>: An optional `tf.DType` from: `tf.int32, tf.int64`. Defaults to `tf.int64`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tuple of `Tensor` objects (output, argmax).
@ -522,14 +475,14 @@ The indices in `argmax` are flattened, so that a maximum value at position
## Normalization <a class="md-anchor" id="AUTOGENERATED-normalization"></a>
## Normalization
Normalization is useful to prevent neurons from saturating when inputs may
have varying scale, and to aid generalization.
- - -
### `tf.nn.l2_normalize(x, dim, epsilon=1e-12, name=None)` <a class="md-anchor" id="l2_normalize"></a>
### `tf.nn.l2_normalize(x, dim, epsilon=1e-12, name=None)` {#l2_normalize}
Normalizes along dimension `dim` using an L2 norm.
@ -540,7 +493,7 @@ For a 1-D tensor with `dim = 0`, computes
For `x` with more dimensions, independently normalizes each 1-D slice along
dimension `dim`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`.
@ -549,14 +502,14 @@ dimension `dim`.
divisor if `norm < sqrt(epsilon)`.
* <b>`name`</b>: A name for this operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same shape as `x`.
- - -
### `tf.nn.local_response_normalization(input, depth_radius=None, bias=None, alpha=None, beta=None, name=None)` <a class="md-anchor" id="local_response_normalization"></a>
### `tf.nn.local_response_normalization(input, depth_radius=None, bias=None, alpha=None, beta=None, name=None)` {#local_response_normalization}
Local Response Normalization.
@ -573,7 +526,7 @@ For details, see [Krizhevsky et al., ImageNet classification with deep
convolutional neural networks (NIPS 2012)]
(http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks).
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor` of type `float32`. 4-D.
@ -586,14 +539,14 @@ convolutional neural networks (NIPS 2012)]
* <b>`beta`</b>: An optional `float`. Defaults to `0.5`. An exponent.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `float32`.
- - -
### `tf.nn.moments(x, axes, name=None)` <a class="md-anchor" id="moments"></a>
### `tf.nn.moments(x, axes, name=None)` {#moments}
Calculate the mean and variance of `x`.
@ -605,7 +558,7 @@ For so-called "global normalization" needed for convolutional filters pass
`axes=[0, 1, 2]` (batch, height, width). For batch normalization pass
`axes=[0]` (batch).
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`x`</b>: A `Tensor`.
@ -613,13 +566,13 @@ For so-called "global normalization" needed for convolutional filters pass
variance.
* <b>`name`</b>: Name used to scope the operations that compute the moments.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Two `Tensor` objects: `mean` and `variance`.
## Losses <a class="md-anchor" id="AUTOGENERATED-losses"></a>
## Losses
The loss ops measure error between two tensors, or between a tensor and zero.
These can be used for measuring accuracy of a network in a regression task
@ -627,7 +580,7 @@ or for regularization purposes (weight decay).
- - -
### `tf.nn.l2_loss(t, name=None)` <a class="md-anchor" id="l2_loss"></a>
### `tf.nn.l2_loss(t, name=None)` {#l2_loss}
L2 Loss.
@ -635,26 +588,26 @@ Computes half the L2 norm of a tensor without the `sqrt`:
output = sum(t ** 2) / 2
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`t`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `qint8`, `quint8`, `qint32`.
Typically 2-D, but may have any dimensions.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `t`. 0-D.
## Classification <a class="md-anchor" id="AUTOGENERATED-classification"></a>
## Classification
TensorFlow provides several operations that help you perform classification.
- - -
### `tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None)` <a class="md-anchor" id="sigmoid_cross_entropy_with_logits"></a>
### `tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None)` {#sigmoid_cross_entropy_with_logits}
Computes sigmoid cross entropy given `logits`.
@ -673,14 +626,14 @@ To ensure stability and avoid overflow, the implementation uses
`logits` and `targets` must have the same type and shape.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`logits`</b>: A `Tensor` of type `float32` or `float64`.
* <b>`targets`</b>: A `Tensor` of the same type and shape as `logits`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of the same shape as `logits` with the componentwise
logistic losses.
@ -688,7 +641,7 @@ To ensure stability and avoid overflow, the implementation uses
- - -
### `tf.nn.softmax(logits, name=None)` <a class="md-anchor" id="softmax"></a>
### `tf.nn.softmax(logits, name=None)` {#softmax}
Computes softmax activations.
@ -696,21 +649,21 @@ For each batch `i` and class `j` we have
softmax[i, j] = exp(logits[i, j]) / sum(exp(logits[i]))
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`logits`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`.
2-D with shape `[batch_size, num_classes]`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
- - -
### `tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)` <a class="md-anchor" id="softmax_cross_entropy_with_logits"></a>
### `tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)` {#softmax_cross_entropy_with_logits}
Computes softmax cross entropy between `logits` and `labels`.
@ -726,28 +679,28 @@ output of `softmax`, as it will produce incorrect results.
`logits` and `labels` must have the same shape `[batch_size, num_classes]`
and the same dtype (either `float32` or `float64`).
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`logits`</b>: Unscaled log probabilities.
* <b>`labels`</b>: Each row `labels[i]` must be a valid probability distribution.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
softmax cross entropy loss.
## Embeddings <a class="md-anchor" id="AUTOGENERATED-embeddings"></a>
## Embeddings
TensorFlow provides library support for looking up values in embedding
tensors.
- - -
### `tf.nn.embedding_lookup(params, ids, name=None)` <a class="md-anchor" id="embedding_lookup"></a>
### `tf.nn.embedding_lookup(params, ids, name=None)` {#embedding_lookup}
Looks up `ids` in a list of embedding tensors.
@ -763,7 +716,7 @@ then used to look up the slice `params[p][id // len(params), ...]`.
The results of the lookup are then concatenated into a dense
tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`params`</b>: A list of tensors with the same shape and type.
@ -771,25 +724,25 @@ tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
up in `params`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` with the same type as the tensors in `params`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If `params` is empty.
## Evaluation <a class="md-anchor" id="AUTOGENERATED-evaluation"></a>
## Evaluation
The evaluation ops are useful for measuring the performance of a network.
Since they are nondifferentiable, they are typically used at evaluation time.
- - -
### `tf.nn.top_k(input, k, name=None)` <a class="md-anchor" id="top_k"></a>
### `tf.nn.top_k(input, k, name=None)` {#top_k}
Returns the values and indices of the k largest elements for each row.
@ -799,7 +752,7 @@ Returns the values and indices of the k largest elements for each row.
such that \\(input_{i, indices_{i, j}} = values_{i, j}\\). If two
elements are equal, the lower-index element appears first.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`input`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `int64`, `uint8`, `int16`, `int8`.
@ -808,7 +761,7 @@ elements are equal, the lower-index element appears first.
Number of top elements to look for within each row
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A tuple of `Tensor` objects (values, indices).
@ -819,7 +772,7 @@ elements are equal, the lower-index element appears first.
- - -
### `tf.nn.in_top_k(predictions, targets, k, name=None)` <a class="md-anchor" id="in_top_k"></a>
### `tf.nn.in_top_k(predictions, targets, k, name=None)` {#in_top_k}
Says whether the targets are in the top K predictions.
@ -838,7 +791,7 @@ More formally, let
$$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`predictions`</b>: A `Tensor` of type `float32`. A batch_size x classes tensor
@ -846,13 +799,13 @@ $$out_i = predictions_{i, targets_i} \in TopKIncludingTies(predictions_i)$$
* <b>`k`</b>: An `int`. Number of top elements to look at for computing precision
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` of type `bool`. Computed Precision at k as a bool Tensor
## Candidate Sampling <a class="md-anchor" id="AUTOGENERATED-candidate-sampling"></a>
## Candidate Sampling
Do you want to train a multiclass or multilabel model with thousands
or millions of output classes (for example, a language model with a
@ -865,13 +818,13 @@ only considering a small randomly-chosen subset of contrastive classes
See our [Candidate Sampling Algorithms Reference]
(../../extras/candidate_sampling.pdf)
### Sampled Loss Functions <a class="md-anchor" id="AUTOGENERATED-sampled-loss-functions"></a>
### Sampled Loss Functions
TensorFlow provides the following sampled loss functions for faster training.
- - -
### `tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=False, name='nce_loss')` <a class="md-anchor" id="nce_loss"></a>
### `tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=False, name='nce_loss')` {#nce_loss}
Computes and returns the noise-contrastive estimation training loss.
@ -891,7 +844,7 @@ For now, if you have a variable number of target classes, you can pad them
out to a constant number by either repeating them or by padding
with an otherwise unused class.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`weights`</b>: A `Tensor` of shape [num_classes, dim]. The class embeddings.
@ -915,14 +868,14 @@ with an otherwise unused class.
Default is False.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A batch_size 1-D tensor of per-example NCE losses.
- - -
### `tf.nn.sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=True, name='sampled_softmax_loss')` <a class="md-anchor" id="sampled_softmax_loss"></a>
### `tf.nn.sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=True, name='sampled_softmax_loss')` {#sampled_softmax_loss}
Computes and returns the sampled softmax training loss.
@ -940,7 +893,7 @@ See our [Candidate Sampling Algorithms Reference]
Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`weights`</b>: A `Tensor` of shape [num_classes, dim]. The class embeddings.
@ -961,20 +914,20 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
True.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A batch_size 1-D tensor of per-example sampled softmax losses.
### Candidate Samplers <a class="md-anchor" id="AUTOGENERATED-candidate-samplers"></a>
### Candidate Samplers
TensorFlow provides the following samplers for randomly sampling candidate
classes when using one of the sampled loss functions above.
- - -
### `tf.nn.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)` <a class="md-anchor" id="uniform_candidate_sampler"></a>
### `tf.nn.uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)` {#uniform_candidate_sampler}
Samples a set of classes using a uniform base distribution.
@ -998,7 +951,7 @@ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
If `unique=True`, then these are post-rejection probabilities and we
compute them approximately.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`true_classes`</b>: A `Tensor` of type `int64` and shape `[batch_size,
@ -1011,7 +964,7 @@ compute them approximately.
* <b>`seed`</b>: An `int`. An operation-specific seed. Default is 0.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`sampled_candidates`</b>: A tensor of type `int64` and shape `[num_sampled]`.
@ -1026,7 +979,7 @@ compute them approximately.
- - -
### `tf.nn.log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)` <a class="md-anchor" id="log_uniform_candidate_sampler"></a>
### `tf.nn.log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)` {#log_uniform_candidate_sampler}
Samples a set of classes using a log-uniform (Zipfian) base distribution.
@ -1057,7 +1010,7 @@ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
If `unique=True`, then these are post-rejection probabilities and we
compute them approximately.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`true_classes`</b>: A `Tensor` of type `int64` and shape `[batch_size,
@ -1070,7 +1023,7 @@ compute them approximately.
* <b>`seed`</b>: An `int`. An operation-specific seed. Default is 0.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`sampled_candidates`</b>: A tensor of type `int64` and shape `[num_sampled]`.
@ -1085,7 +1038,7 @@ compute them approximately.
- - -
### `tf.nn.learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)` <a class="md-anchor" id="learned_unigram_candidate_sampler"></a>
### `tf.nn.learned_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, seed=None, name=None)` {#learned_unigram_candidate_sampler}
Samples a set of classes from a distribution learned during training.
@ -1113,7 +1066,7 @@ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
If `unique=True`, then these are post-rejection probabilities and we
compute them approximately.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`true_classes`</b>: A `Tensor` of type `int64` and shape `[batch_size,
@ -1126,7 +1079,7 @@ compute them approximately.
* <b>`seed`</b>: An `int`. An operation-specific seed. Default is 0.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`sampled_candidates`</b>: A tensor of type `int64` and shape `[num_sampled]`.
@ -1141,7 +1094,7 @@ compute them approximately.
- - -
### `tf.nn.fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, vocab_file='', distortion=0.0, num_reserved_ids=0, num_shards=1, shard=0, unigrams=[], seed=None, name=None)` <a class="md-anchor" id="fixed_unigram_candidate_sampler"></a>
### `tf.nn.fixed_unigram_candidate_sampler(true_classes, num_true, num_sampled, unique, range_max, vocab_file='', distortion=0.0, num_reserved_ids=0, num_shards=1, shard=0, unigrams=[], seed=None, name=None)` {#fixed_unigram_candidate_sampler}
Samples a set of classes using the provided (fixed) base distribution.
@ -1166,7 +1119,7 @@ document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
If `unique=True`, then these are post-rejection probabilities and we
compute them approximately.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`true_classes`</b>: A `Tensor` of type `int64` and shape `[batch_size,
@ -1204,7 +1157,7 @@ compute them approximately.
* <b>`seed`</b>: An `int`. An operation-specific seed. Default is 0.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`sampled_candidates`</b>: A tensor of type `int64` and shape `[num_sampled]`.
@ -1218,11 +1171,11 @@ compute them approximately.
### Miscellaneous candidate sampling utilities <a class="md-anchor" id="AUTOGENERATED-miscellaneous-candidate-sampling-utilities"></a>
### Miscellaneous candidate sampling utilities
- - -
### `tf.nn.compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None)` <a class="md-anchor" id="compute_accidental_hits"></a>
### `tf.nn.compute_accidental_hits(true_classes, sampled_candidates, num_true, seed=None, name=None)` {#compute_accidental_hits}
Compute the ids of positions in sampled_candidates matching true_classes.
@ -1246,7 +1199,7 @@ operation, then added to the logits of the sampled classes. This
removes the contradictory effect of accidentally sampling the true
target classes as noise classes for the same example.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`true_classes`</b>: A `Tensor` of type `int64` and shape `[batch_size,
@ -1257,7 +1210,7 @@ target classes as noise classes for the same example.
* <b>`seed`</b>: An `int`. An operation-specific seed. Default is 0.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`indices`</b>: A `Tensor` of type `int32` and shape `[num_accidental_hits]`.

View File

@ -1,18 +1,9 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Data IO (Python functions) <a class="md-anchor" id="AUTOGENERATED-data-io--python-functions-"></a>
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Data IO (Python functions)](#AUTOGENERATED-data-io--python-functions-)
* [Data IO (Python Functions)](#AUTOGENERATED-data-io--python-functions-)
* [`class tf.python_io.TFRecordWriter`](#TFRecordWriter)
* [`tf.python_io.tf_record_iterator(path)`](#tf_record_iterator)
* [TFRecords Format Details](#AUTOGENERATED-tfrecords-format-details)
# Data IO (Python functions)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Data IO (Python Functions) <a class="md-anchor" id="AUTOGENERATED-data-io--python-functions-"></a>
## Data IO (Python Functions)
A TFRecords file represents a sequence of (binary) strings. The format is not
random access, so it is suitable for streaming large amounts of data but not
@ -20,7 +11,7 @@ suitable if fast sharding or other non-sequential access is desired.
- - -
### `class tf.python_io.TFRecordWriter` <a class="md-anchor" id="TFRecordWriter"></a>
### `class tf.python_io.TFRecordWriter` {#TFRecordWriter}
A class to write records to a TFRecords file.
@ -29,16 +20,16 @@ in `with` blocks like a normal file.
- - -
#### `tf.python_io.TFRecordWriter.__init__(path)` <a class="md-anchor" id="TFRecordWriter.__init__"></a>
#### `tf.python_io.TFRecordWriter.__init__(path)` {#TFRecordWriter.__init__}
Opens file `path` and creates a `TFRecordWriter` writing to it.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`path`</b>: The path to the TFRecords file.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`IOError`</b>: If `path` cannot be opened for writing.
@ -46,11 +37,11 @@ Opens file `path` and creates a `TFRecordWriter` writing to it.
- - -
#### `tf.python_io.TFRecordWriter.write(record)` <a class="md-anchor" id="TFRecordWriter.write"></a>
#### `tf.python_io.TFRecordWriter.write(record)` {#TFRecordWriter.write}
Write a string record to the file.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`record`</b>: str
@ -58,7 +49,7 @@ Write a string record to the file.
- - -
#### `tf.python_io.TFRecordWriter.close()` <a class="md-anchor" id="TFRecordWriter.close"></a>
#### `tf.python_io.TFRecordWriter.close()` {#TFRecordWriter.close}
Close the file.
@ -66,20 +57,20 @@ Close the file.
- - -
### `tf.python_io.tf_record_iterator(path)` <a class="md-anchor" id="tf_record_iterator"></a>
### `tf.python_io.tf_record_iterator(path)` {#tf_record_iterator}
An iterator that read the records from a TFRecords file.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`path`</b>: The path to the TFRecords file.
##### Yields: <a class="md-anchor" id="AUTOGENERATED-yields-"></a>
##### Yields:
Strings.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`IOError`</b>: If `path` cannot be opened for reading.
@ -88,7 +79,7 @@ An iterator that read the records from a TFRecords file.
- - -
### TFRecords Format Details <a class="md-anchor" id="AUTOGENERATED-tfrecords-format-details"></a>
### TFRecords Format Details
A TFRecords file contains a sequence of strings with CRC hashes. Each record
has the format

View File

@ -1,30 +1,13 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Sparse Tensors <a class="md-anchor" id="AUTOGENERATED-sparse-tensors"></a>
# Sparse Tensors
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Sparse Tensors](#AUTOGENERATED-sparse-tensors)
* [Sparse Tensor Representation](#AUTOGENERATED-sparse-tensor-representation)
* [`class tf.SparseTensor`](#SparseTensor)
* [`class tf.SparseTensorValue`](#SparseTensorValue)
* [Sparse to Dense Conversion](#AUTOGENERATED-sparse-to-dense-conversion)
* [`tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value, name=None)`](#sparse_to_dense)
* [`tf.sparse_tensor_to_dense(sp_input, default_value, name=None)`](#sparse_tensor_to_dense)
* [`tf.sparse_to_indicator(sp_input, vocab_size, name=None)`](#sparse_to_indicator)
* [Manipulation](#AUTOGENERATED-manipulation)
* [`tf.sparse_concat(concat_dim, sp_inputs, name=None)`](#sparse_concat)
* [`tf.sparse_reorder(sp_input, name=None)`](#sparse_reorder)
* [`tf.sparse_retain(sp_input, to_retain)`](#sparse_retain)
* [`tf.sparse_fill_empty_rows(sp_input, default_value, name=None)`](#sparse_fill_empty_rows)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Sparse Tensor Representation <a class="md-anchor" id="AUTOGENERATED-sparse-tensor-representation"></a>
## Sparse Tensor Representation
Tensorflow supports a `SparseTensor` representation for data that is sparse
in multiple dimensions. Contrast this representation with `IndexedSlices`,
@ -33,7 +16,7 @@ dimension, and dense along all other dimensions.
- - -
### `class tf.SparseTensor` <a class="md-anchor" id="SparseTensor"></a>
### `class tf.SparseTensor` {#SparseTensor}
Represents a sparse tensor.
@ -81,92 +64,92 @@ represents the dense tensor
- - -
#### `tf.SparseTensor.__init__(indices, values, shape)` <a class="md-anchor" id="SparseTensor.__init__"></a>
#### `tf.SparseTensor.__init__(indices, values, shape)` {#SparseTensor.__init__}
Creates a `SparseTensor`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`indices`</b>: A 2-D int64 tensor of shape `[N, ndims]`.
* <b>`values`</b>: A 1-D tensor of any type and shape `[N]`.
* <b>`dense_shape`</b>: A 1-D int64 tensor of shape `[ndims]`.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `SparseTensor`
- - -
#### `tf.SparseTensor.indices` <a class="md-anchor" id="SparseTensor.indices"></a>
#### `tf.SparseTensor.indices` {#SparseTensor.indices}
The indices of non-zero values in the represented dense tensor.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 2-D Tensor of int64 with shape `[N, ndims]`, where `N` is the
number of non-zero values in the tensor, and `ndims` is the rank.
- - -
#### `tf.SparseTensor.values` <a class="md-anchor" id="SparseTensor.values"></a>
#### `tf.SparseTensor.values` {#SparseTensor.values}
The non-zero values in the represented dense tensor.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A 1-D Tensor of any data type.
- - -
#### `tf.SparseTensor.dtype` <a class="md-anchor" id="SparseTensor.dtype"></a>
#### `tf.SparseTensor.dtype` {#SparseTensor.dtype}
The `DType` of elements in this tensor.
- - -
#### `tf.SparseTensor.shape` <a class="md-anchor" id="SparseTensor.shape"></a>
#### `tf.SparseTensor.shape` {#SparseTensor.shape}
A 1-D Tensor of int64 representing the shape of the dense tensor.
- - -
#### `tf.SparseTensor.graph` <a class="md-anchor" id="SparseTensor.graph"></a>
#### `tf.SparseTensor.graph` {#SparseTensor.graph}
The `Graph` that contains the index, value, and shape tensors.
- - -
### `class tf.SparseTensorValue` <a class="md-anchor" id="SparseTensorValue"></a>
### `class tf.SparseTensorValue` {#SparseTensorValue}
SparseTensorValue(indices, values, shape)
- - -
#### `tf.SparseTensorValue.indices` <a class="md-anchor" id="SparseTensorValue.indices"></a>
#### `tf.SparseTensorValue.indices` {#SparseTensorValue.indices}
Alias for field number 0
- - -
#### `tf.SparseTensorValue.shape` <a class="md-anchor" id="SparseTensorValue.shape"></a>
#### `tf.SparseTensorValue.shape` {#SparseTensorValue.shape}
Alias for field number 2
- - -
#### `tf.SparseTensorValue.values` <a class="md-anchor" id="SparseTensorValue.values"></a>
#### `tf.SparseTensorValue.values` {#SparseTensorValue.values}
Alias for field number 1
## Sparse to Dense Conversion <a class="md-anchor" id="AUTOGENERATED-sparse-to-dense-conversion"></a>
## Sparse to Dense Conversion
- - -
### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value, name=None)` <a class="md-anchor" id="sparse_to_dense"></a>
### `tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value, name=None)` {#sparse_to_dense}
Converts a sparse representation into a dense tensor.
@ -186,7 +169,7 @@ dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i]
All other values in `dense` are set to `default_value`. If `sparse_values` is a
scalar, all sparse indices are set to this single value.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sparse_indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
@ -202,7 +185,7 @@ scalar, all sparse indices are set to this single value.
`sparse_indices`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor`. Has the same type as `sparse_values`.
Dense output tensor of shape `output_shape`.
@ -210,7 +193,7 @@ scalar, all sparse indices are set to this single value.
- - -
### `tf.sparse_tensor_to_dense(sp_input, default_value, name=None)` <a class="md-anchor" id="sparse_tensor_to_dense"></a>
### `tf.sparse_tensor_to_dense(sp_input, default_value, name=None)` {#sparse_tensor_to_dense}
Converts a `SparseTensor` into a dense tensor.
@ -229,7 +212,7 @@ string tensor with values:
[x x x x x]
[c x x x x]]
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sp_input`</b>: The input `SparseTensor`.
@ -237,13 +220,13 @@ string tensor with values:
`sp_input`.
* <b>`name`</b>: A name prefix for the returned tensors (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A dense tensor with shape `sp_input.shape` and values specified by
the non-empty values in `sp_input`. Indices not in `sp_input` are assigned
`default_value`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
@ -251,7 +234,7 @@ string tensor with values:
- - -
### `tf.sparse_to_indicator(sp_input, vocab_size, name=None)` <a class="md-anchor" id="sparse_to_indicator"></a>
### `tf.sparse_to_indicator(sp_input, vocab_size, name=None)` {#sparse_to_indicator}
Converts a `SparseTensor` of ids into a dense bool indicator tensor.
@ -282,7 +265,7 @@ compatibility with ops that expect dense tensors.
The input `SparseTensor` must be in row-major order.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sp_input`</b>: A `SparseTensor` of type `int32` or `int64`.
@ -290,22 +273,22 @@ The input `SparseTensor` must be in row-major order.
`all(0 <= sp_input.values < vocab_size)`.
* <b>`name`</b>: A name prefix for the returned tensors (optional)
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A dense bool indicator tensor representing the indices with specified value.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
## Manipulation <a class="md-anchor" id="AUTOGENERATED-manipulation"></a>
## Manipulation
- - -
### `tf.sparse_concat(concat_dim, sp_inputs, name=None)` <a class="md-anchor" id="sparse_concat"></a>
### `tf.sparse_concat(concat_dim, sp_inputs, name=None)` {#sparse_concat}
Concatenates a list of `SparseTensor` along the specified dimension.
@ -351,18 +334,18 @@ Graphically this is equivalent to doing
[ a] concat [ d e ] = [ a d e ]
[b c ] [ ] [b c ]
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`concat_dim`</b>: Dimension to concatenate along.
* <b>`sp_inputs`</b>: List of `SparseTensor` to concatenate.
* <b>`name`</b>: A name prefix for the returned tensors (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `SparseTensor` with the concatenated output.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sp_inputs` is not a list of `SparseTensor`.
@ -370,7 +353,7 @@ Graphically this is equivalent to doing
- - -
### `tf.sparse_reorder(sp_input, name=None)` <a class="md-anchor" id="sparse_reorder"></a>
### `tf.sparse_reorder(sp_input, name=None)` {#sparse_reorder}
Reorders a `SparseTensor` into the canonical, row-major ordering.
@ -395,18 +378,18 @@ then the output will be a `SparseTensor` of shape `[4, 5]` and
[2, 0]: c
[3, 1]: d
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sp_input`</b>: The input `SparseTensor`.
* <b>`name`</b>: A name prefix for the returned tensors (optional)
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `SparseTensor` with the same shape and non-empty values, but in
canonical ordering.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
@ -414,7 +397,7 @@ then the output will be a `SparseTensor` of shape `[4, 5]` and
- - -
### `tf.sparse_retain(sp_input, to_retain)` <a class="md-anchor" id="sparse_retain"></a>
### `tf.sparse_retain(sp_input, to_retain)` {#sparse_retain}
Retains specified non-empty values within a `SparseTensor`.
@ -431,18 +414,18 @@ be a `SparseTensor` of shape `[4, 5]` with 2 non-empty values:
[0, 1]: a
[3, 1]: d
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sp_input`</b>: The input `SparseTensor` with `N` non-empty elements.
* <b>`to_retain`</b>: A bool vector of length `N` with `M` true values.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `SparseTensor` with the same shape as the input and `M` non-empty
elements corresponding to the true positions in `to_retain`.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.
@ -450,7 +433,7 @@ be a `SparseTensor` of shape `[4, 5]` with 2 non-empty values:
- - -
### `tf.sparse_fill_empty_rows(sp_input, default_value, name=None)` <a class="md-anchor" id="sparse_fill_empty_rows"></a>
### `tf.sparse_fill_empty_rows(sp_input, default_value, name=None)` {#sparse_fill_empty_rows}
Fills empty rows in the input 2-D `SparseTensor` with a default value.
@ -483,7 +466,7 @@ This op also returns an indicator vector such that
empty_row_indicator[i] = True iff row i was an empty row.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sp_input`</b>: A `SparseTensor` with shape `[N, M]`.
@ -491,7 +474,7 @@ This op also returns an indicator vector such that
`sp_input.`
* <b>`name`</b>: A name prefix for the returned tensors (optional)
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
* <b>`sp_ordered_output`</b>: A `SparseTensor` with shape `[N, M]`, and with all empty
@ -499,7 +482,7 @@ This op also returns an indicator vector such that
* <b>`empty_row_indicator`</b>: A bool vector of length `N` indicating whether each
input row was empty.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sp_input` is not a `SparseTensor`.

View File

@ -1,51 +1,17 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Variables <a class="md-anchor" id="AUTOGENERATED-variables"></a>
# Variables
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](../../api_docs/python/framework.md#convert_to_tensor).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Variables](#AUTOGENERATED-variables)
* [Variables](#AUTOGENERATED-variables)
* [`class tf.Variable`](#Variable)
* [Variable helper functions](#AUTOGENERATED-variable-helper-functions)
* [`tf.all_variables()`](#all_variables)
* [`tf.trainable_variables()`](#trainable_variables)
* [`tf.initialize_all_variables()`](#initialize_all_variables)
* [`tf.initialize_variables(var_list, name='init')`](#initialize_variables)
* [`tf.assert_variables_initialized(var_list=None)`](#assert_variables_initialized)
* [Saving and Restoring Variables](#AUTOGENERATED-saving-and-restoring-variables)
* [`class tf.train.Saver`](#Saver)
* [`tf.train.latest_checkpoint(checkpoint_dir, latest_filename=None)`](#latest_checkpoint)
* [`tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)`](#get_checkpoint_state)
* [`tf.train.update_checkpoint_state(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None, latest_filename=None)`](#update_checkpoint_state)
* [Sharing Variables](#AUTOGENERATED-sharing-variables)
* [`tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)`](#get_variable)
* [`tf.get_variable_scope()`](#get_variable_scope)
* [`tf.variable_scope(name_or_scope, reuse=None, initializer=None)`](#variable_scope)
* [`tf.constant_initializer(value=0.0)`](#constant_initializer)
* [`tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None)`](#random_normal_initializer)
* [`tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None)`](#truncated_normal_initializer)
* [`tf.random_uniform_initializer(minval=0.0, maxval=1.0, seed=None)`](#random_uniform_initializer)
* [`tf.uniform_unit_scaling_initializer(factor=1.0, seed=None)`](#uniform_unit_scaling_initializer)
* [`tf.zeros_initializer(shape, dtype=tf.float32)`](#zeros_initializer)
* [Sparse Variable Updates](#AUTOGENERATED-sparse-variable-updates)
* [`tf.scatter_update(ref, indices, updates, use_locking=None, name=None)`](#scatter_update)
* [`tf.scatter_add(ref, indices, updates, use_locking=None, name=None)`](#scatter_add)
* [`tf.scatter_sub(ref, indices, updates, use_locking=None, name=None)`](#scatter_sub)
* [`tf.sparse_mask(a, mask_indices, name=None)`](#sparse_mask)
* [`class tf.IndexedSlices`](#IndexedSlices)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Variables <a class="md-anchor" id="AUTOGENERATED-variables"></a>
## Variables
- - -
### `class tf.Variable` <a class="md-anchor" id="Variable"></a>
### `class tf.Variable` {#Variable}
See the [Variables How To](../../how_tos/variables/index.md) for a high
level overview.
@ -138,7 +104,7 @@ Creating a variable.
- - -
#### `tf.Variable.__init__(initial_value, trainable=True, collections=None, validate_shape=True, name=None)` <a class="md-anchor" id="Variable.__init__"></a>
#### `tf.Variable.__init__(initial_value, trainable=True, collections=None, validate_shape=True, name=None)` {#Variable.__init__}
Creates a new variable with value `initial_value`.
@ -151,7 +117,7 @@ If `trainable` is `True` the variable is also added to the graph collection
This constructor creates both a `variable` Op and an `assign` Op to set the
variable to its initial value.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`initial_value`</b>: A `Tensor`, or Python object convertible to a `Tensor`.
@ -168,11 +134,11 @@ variable to its initial value.
* <b>`name`</b>: Optional name for the variable. Defaults to `'Variable'` and gets
uniquified automatically.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A Variable.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: If the initial value does not have a shape and
@ -181,7 +147,7 @@ variable to its initial value.
- - -
#### `tf.Variable.initialized_value()` <a class="md-anchor" id="Variable.initialized_value"></a>
#### `tf.Variable.initialized_value()` {#Variable.initialized_value}
Returns the value of the initialized variable.
@ -197,7 +163,7 @@ v = tf.Variable(tf.truncated_normal([10, 40]))
w = tf.Variable(v.initialized_value() * 2.0)
```
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` holding the value of this variable after its initializer
has run.
@ -208,19 +174,19 @@ Changing a variable value.
- - -
#### `tf.Variable.assign(value, use_locking=False)` <a class="md-anchor" id="Variable.assign"></a>
#### `tf.Variable.assign(value, use_locking=False)` {#Variable.assign}
Assigns a new value to the variable.
This is essentially a shortcut for `assign(self, value)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A `Tensor`. The new value for this variable.
* <b>`use_locking`</b>: If `True`, use locking during the assignment.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` that will hold the new value of this variable after
the assignment has completed.
@ -228,19 +194,19 @@ This is essentially a shortcut for `assign(self, value)`.
- - -
#### `tf.Variable.assign_add(delta, use_locking=False)` <a class="md-anchor" id="Variable.assign_add"></a>
#### `tf.Variable.assign_add(delta, use_locking=False)` {#Variable.assign_add}
Adds a value to this variable.
This is essentially a shortcut for `assign_add(self, delta)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`delta`</b>: A `Tensor`. The value to add to this variable.
* <b>`use_locking`</b>: If `True`, use locking during the operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` that will hold the new value of this variable after
the addition has completed.
@ -248,19 +214,19 @@ Adds a value to this variable.
- - -
#### `tf.Variable.assign_sub(delta, use_locking=False)` <a class="md-anchor" id="Variable.assign_sub"></a>
#### `tf.Variable.assign_sub(delta, use_locking=False)` {#Variable.assign_sub}
Subtracts a value from this variable.
This is essentially a shortcut for `assign_sub(self, delta)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`delta`</b>: A `Tensor`. The value to subtract from this variable.
* <b>`use_locking`</b>: If `True`, use locking during the operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` that will hold the new value of this variable after
the subtraction has completed.
@ -268,25 +234,25 @@ This is essentially a shortcut for `assign_sub(self, delta)`.
- - -
#### `tf.Variable.scatter_sub(sparse_delta, use_locking=False)` <a class="md-anchor" id="Variable.scatter_sub"></a>
#### `tf.Variable.scatter_sub(sparse_delta, use_locking=False)` {#Variable.scatter_sub}
Subtracts `IndexedSlices` from this variable.
This is essentially a shortcut for `scatter_sub(self, sparse_delta.indices,
sparse_delta.values)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sparse_delta`</b>: `IndexedSlices` to be subtracted from this variable.
* <b>`use_locking`</b>: If `True`, use locking during the operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` that will hold the new value of this variable after
the scattered subtraction has completed.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: if `sparse_delta` is not an `IndexedSlices`.
@ -294,7 +260,7 @@ sparse_delta.values)`.
- - -
#### `tf.Variable.count_up_to(limit)` <a class="md-anchor" id="Variable.count_up_to"></a>
#### `tf.Variable.count_up_to(limit)` {#Variable.count_up_to}
Increments this variable until it reaches `limit`.
@ -307,12 +273,12 @@ the increment.
This is essentially a shortcut for `count_up_to(self, limit)`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`limit`</b>: value at which incrementing the variable raises an error.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `Tensor` that will hold the variable value before the increment. If no
other Op modifies this variable, the values produced will all be
@ -322,7 +288,7 @@ This is essentially a shortcut for `count_up_to(self, limit)`.
- - -
#### `tf.Variable.eval(session=None)` <a class="md-anchor" id="Variable.eval"></a>
#### `tf.Variable.eval(session=None)` {#Variable.eval}
In a session, computes and returns the value of this variable.
@ -346,13 +312,13 @@ with tf.Session() as sess:
print v.eval()
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`session`</b>: The session to use to evaluate this variable. If
none, the default session is used.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A numpy `ndarray` with a copy of the value of this variable.
@ -362,61 +328,61 @@ Properties.
- - -
#### `tf.Variable.name` <a class="md-anchor" id="Variable.name"></a>
#### `tf.Variable.name` {#Variable.name}
The name of this variable.
- - -
#### `tf.Variable.dtype` <a class="md-anchor" id="Variable.dtype"></a>
#### `tf.Variable.dtype` {#Variable.dtype}
The `DType` of this variable.
- - -
#### `tf.Variable.get_shape()` <a class="md-anchor" id="Variable.get_shape"></a>
#### `tf.Variable.get_shape()` {#Variable.get_shape}
The `TensorShape` of this variable.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `TensorShape`.
- - -
#### `tf.Variable.device` <a class="md-anchor" id="Variable.device"></a>
#### `tf.Variable.device` {#Variable.device}
The device of this variable.
- - -
#### `tf.Variable.initializer` <a class="md-anchor" id="Variable.initializer"></a>
#### `tf.Variable.initializer` {#Variable.initializer}
The initializer operation for this variable.
- - -
#### `tf.Variable.graph` <a class="md-anchor" id="Variable.graph"></a>
#### `tf.Variable.graph` {#Variable.graph}
The `Graph` of this variable.
- - -
#### `tf.Variable.op` <a class="md-anchor" id="Variable.op"></a>
#### `tf.Variable.op` {#Variable.op}
The `Operation` of this variable.
## Variable helper functions <a class="md-anchor" id="AUTOGENERATED-variable-helper-functions"></a>
## Variable helper functions
TensorFlow provides a set of functions to help manage the set of variables
collected in the graph.
- - -
### `tf.all_variables()` <a class="md-anchor" id="all_variables"></a>
### `tf.all_variables()` {#all_variables}
Returns all variables collected in the graph.
@ -424,14 +390,14 @@ The `Variable()` constructor automatically adds new variables to the graph
collection `GraphKeys.VARIABLES`. This convenience function returns the
contents of that collection.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A list of `Variable` objects.
- - -
### `tf.trainable_variables()` <a class="md-anchor" id="trainable_variables"></a>
### `tf.trainable_variables()` {#trainable_variables}
Returns all variables created with `trainable=True`.
@ -440,7 +406,7 @@ adds new variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the
contents of that collection.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A list of Variable objects.
@ -448,20 +414,20 @@ contents of that collection.
- - -
### `tf.initialize_all_variables()` <a class="md-anchor" id="initialize_all_variables"></a>
### `tf.initialize_all_variables()` {#initialize_all_variables}
Returns an Op that initializes all variables.
This is just a shortcut for `initialize_variables(all_variables())`
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An Op that initializes all variables in the graph.
- - -
### `tf.initialize_variables(var_list, name='init')` <a class="md-anchor" id="initialize_variables"></a>
### `tf.initialize_variables(var_list, name='init')` {#initialize_variables}
Returns an Op that initializes a list of variables.
@ -475,20 +441,20 @@ initializers to `Group()`.
If `var_list` is empty, however, the function still returns an Op that can
be run. That Op just has no effect.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`var_list`</b>: List of `Variable` objects to initialize.
* <b>`name`</b>: Optional name for the returned operation.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An Op that run the initializers of all the specified variables.
- - -
### `tf.assert_variables_initialized(var_list=None)` <a class="md-anchor" id="assert_variables_initialized"></a>
### `tf.assert_variables_initialized(var_list=None)` {#assert_variables_initialized}
Returns an Op to check if variables are initialized.
@ -499,23 +465,23 @@ Note: This function is implemented by trying to fetch the values of the
variables. If one of the variables is not initialized a message may be
logged by the C++ runtime. This is expected.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`var_list`</b>: List of `Variable` objects to check. Defaults to the
value of `all_variables().`
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An Op, or None if there are no variables.
## Saving and Restoring Variables <a class="md-anchor" id="AUTOGENERATED-saving-and-restoring-variables"></a>
## Saving and Restoring Variables
- - -
### `class tf.train.Saver` <a class="md-anchor" id="Saver"></a>
### `class tf.train.Saver` {#Saver}
Saves and restores variables.
@ -591,7 +557,7 @@ protocol buffer file in the call to `save()`.
- - -
#### `tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)` <a class="md-anchor" id="Saver.__init__"></a>
#### `tf.train.Saver.__init__(var_list=None, reshape=False, sharded=False, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)` {#Saver.__init__}
Creates a `Saver`.
@ -629,7 +595,7 @@ want to reload it from an older checkpoint.
The optional `sharded` argument, if True, instructs the saver to shard
checkpoints per device.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`var_list`</b>: A list of Variables or a dictionary mapping names to
@ -653,7 +619,7 @@ checkpoints per device.
* <b>`builder`</b>: Optional SaverBuilder to use if a saver_def was not provided.
Defaults to BaseSaverBuilder().
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `var_list` is invalid.
@ -662,7 +628,7 @@ checkpoints per device.
- - -
#### `tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None)` <a class="md-anchor" id="Saver.save"></a>
#### `tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None)` {#Saver.save}
Saves variables.
@ -673,7 +639,7 @@ save must also have been initialized.
The method returns the path of the newly created checkpoint file. This
path can be passed directly to a call to `restore()`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sess`</b>: A Session to use to save the variables.
@ -688,13 +654,13 @@ path can be passed directly to a call to `restore()`.
managed by the saver to keep track of recent checkpoints. Defaults to
'checkpoint'.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A string: path at which the variables were saved. If the saver is
sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
is the number of shards created.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`TypeError`</b>: If `sess` is not a Session.
@ -702,7 +668,7 @@ path can be passed directly to a call to `restore()`.
- - -
#### `tf.train.Saver.restore(sess, save_path)` <a class="md-anchor" id="Saver.restore"></a>
#### `tf.train.Saver.restore(sess, save_path)` {#Saver.restore}
Restores previously saved variables.
@ -714,7 +680,7 @@ to initialize variables.
The `save_path` argument is typically a value previously returned from a
`save()` call, or a call to `latest_checkpoint()`.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`sess`</b>: A Session to use to restore the parameters.
@ -726,28 +692,28 @@ Other utility methods.
- - -
#### `tf.train.Saver.last_checkpoints` <a class="md-anchor" id="Saver.last_checkpoints"></a>
#### `tf.train.Saver.last_checkpoints` {#Saver.last_checkpoints}
List of not-yet-deleted checkpoint filenames.
You can pass any of the returned values to `restore()`.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A list of checkpoint filenames, sorted from oldest to newest.
- - -
#### `tf.train.Saver.set_last_checkpoints(last_checkpoints)` <a class="md-anchor" id="Saver.set_last_checkpoints"></a>
#### `tf.train.Saver.set_last_checkpoints(last_checkpoints)` {#Saver.set_last_checkpoints}
Sets the list of not-yet-deleted checkpoint filenames.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`last_checkpoints`</b>: a list of checkpoint filenames.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`AssertionError`</b>: if the list of checkpoint filenames has already been set.
@ -755,11 +721,11 @@ Sets the list of not-yet-deleted checkpoint filenames.
- - -
#### `tf.train.Saver.as_saver_def()` <a class="md-anchor" id="Saver.as_saver_def"></a>
#### `tf.train.Saver.as_saver_def()` {#Saver.as_saver_def}
Generates a `SaverDef` representation of this saver.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A `SaverDef` proto.
@ -768,11 +734,11 @@ Generates a `SaverDef` representation of this saver.
- - -
### `tf.train.latest_checkpoint(checkpoint_dir, latest_filename=None)` <a class="md-anchor" id="latest_checkpoint"></a>
### `tf.train.latest_checkpoint(checkpoint_dir, latest_filename=None)` {#latest_checkpoint}
Finds the filename of latest saved checkpoint file.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`checkpoint_dir`</b>: Directory where the variables were saved.
@ -780,7 +746,7 @@ Finds the filename of latest saved checkpoint file.
contains the list of most recent checkpoint filenames.
See the corresponding argument to `Saver.save()`.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The full path to the latest checkpoint or None if no checkpoint was found.
@ -788,21 +754,21 @@ Finds the filename of latest saved checkpoint file.
- - -
### `tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)` <a class="md-anchor" id="get_checkpoint_state"></a>
### `tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)` {#get_checkpoint_state}
Returns CheckpointState proto from the "checkpoint" file.
If the "checkpoint" file contains a valid CheckpointState
proto, returns it.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`checkpoint_dir`</b>: The directory of checkpoints.
* <b>`latest_filename`</b>: Optional name of the checkpoint file. Default to
'checkpoint'.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
A CheckpointState if the state was available, None
otherwise.
@ -810,14 +776,14 @@ proto, returns it.
- - -
### `tf.train.update_checkpoint_state(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None, latest_filename=None)` <a class="md-anchor" id="update_checkpoint_state"></a>
### `tf.train.update_checkpoint_state(save_dir, model_checkpoint_path, all_model_checkpoint_paths=None, latest_filename=None)` {#update_checkpoint_state}
Updates the content of the 'checkpoint' file.
This updates the checkpoint file containing a CheckpointState
proto.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`save_dir`</b>: Directory where the model was saved.
@ -829,21 +795,21 @@ proto.
* <b>`latest_filename`</b>: Optional name of the checkpoint file. Default to
'checkpoint'.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`RuntimeError`</b>: If the save paths conflict.
## Sharing Variables <a class="md-anchor" id="AUTOGENERATED-sharing-variables"></a>
## Sharing Variables
TensorFlow provides several classes and operations that you can use to
create variables contingent on certain conditions.
- - -
### `tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)` <a class="md-anchor" id="get_variable"></a>
### `tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=True, collections=None)` {#get_variable}
Gets an existing variable with these parameters or create a new one.
@ -864,7 +830,7 @@ If initializer is `None` (the default), the default initializer passed in
the constructor is used. If that one is `None` too, a
`UniformUnitScalingInitializer` will be used.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`name`</b>: the name of the new or existing variable.
@ -876,11 +842,11 @@ the constructor is used. If that one is `None` too, a
* <b>`collections`</b>: List of graph collections keys to add the Variable to.
Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The created or existing variable.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: when creating a new variable and shape is not declared,
@ -890,14 +856,14 @@ the constructor is used. If that one is `None` too, a
- - -
### `tf.get_variable_scope()` <a class="md-anchor" id="get_variable_scope"></a>
### `tf.get_variable_scope()` {#get_variable_scope}
Returns the current variable scope.
- - -
### `tf.variable_scope(name_or_scope, reuse=None, initializer=None)` <a class="md-anchor" id="variable_scope"></a>
### `tf.variable_scope(name_or_scope, reuse=None, initializer=None)` {#variable_scope}
Returns a context for variable scope.
@ -957,7 +923,7 @@ with tf.variable_scope("foo", reuse=True):
Note that the `reuse` flag is inherited: if we open a reusing scope,
then all its sub-scopes become reusing as well.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`name_or_scope`</b>: `string` or `VariableScope`: the scope to open.
@ -965,11 +931,11 @@ then all its sub-scopes become reusing as well.
well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
* <b>`initializer`</b>: default initializer for variables within this scope.
##### Yields: <a class="md-anchor" id="AUTOGENERATED-yields-"></a>
##### Yields:
A scope that can be to captured and reused.
##### Raises: <a class="md-anchor" id="AUTOGENERATED-raises-"></a>
##### Raises:
* <b>`ValueError`</b>: when trying to reuse within a create scope, or create within
@ -980,28 +946,28 @@ then all its sub-scopes become reusing as well.
- - -
### `tf.constant_initializer(value=0.0)` <a class="md-anchor" id="constant_initializer"></a>
### `tf.constant_initializer(value=0.0)` {#constant_initializer}
Returns an initializer that generates Tensors with a single value.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`value`</b>: A Python scalar. All elements of the initialized variable
will be set to this value.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An initializer that generates Tensors with a single value.
- - -
### `tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None)` <a class="md-anchor" id="random_normal_initializer"></a>
### `tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None)` {#random_normal_initializer}
Returns an initializer that generates Tensors with a normal distribution.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`mean`</b>: a python scalar or a scalar tensor. Mean of the random values
@ -1012,14 +978,14 @@ Returns an initializer that generates Tensors with a normal distribution.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An initializer that generates Tensors with a normal distribution.
- - -
### `tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None)` <a class="md-anchor" id="truncated_normal_initializer"></a>
### `tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None)` {#truncated_normal_initializer}
Returns an initializer that generates a truncated normal distribution.
@ -1028,7 +994,7 @@ except that values more than two standard deviations from the mean
are discarded and re-drawn. This is the recommended initializer for
neural network weights and filters.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`mean`</b>: a python scalar or a scalar tensor. Mean of the random values
@ -1039,7 +1005,7 @@ neural network weights and filters.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An initializer that generates Tensors with a truncated normal
distribution.
@ -1047,11 +1013,11 @@ neural network weights and filters.
- - -
### `tf.random_uniform_initializer(minval=0.0, maxval=1.0, seed=None)` <a class="md-anchor" id="random_uniform_initializer"></a>
### `tf.random_uniform_initializer(minval=0.0, maxval=1.0, seed=None)` {#random_uniform_initializer}
Returns an initializer that generates Tensors with a uniform distribution.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`minval`</b>: a python scalar or a scalar tensor. lower bound of the range
@ -1062,14 +1028,14 @@ Returns an initializer that generates Tensors with a uniform distribution.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An initializer that generates Tensors with a uniform distribution.
- - -
### `tf.uniform_unit_scaling_initializer(factor=1.0, seed=None)` <a class="md-anchor" id="uniform_unit_scaling_initializer"></a>
### `tf.uniform_unit_scaling_initializer(factor=1.0, seed=None)` {#uniform_unit_scaling_initializer}
Returns an initializer that generates tensors without scaling variance.
@ -1088,7 +1054,7 @@ See <https://arxiv.org/pdf/1412.6558v3.pdf> for deeper motivation, experiments
and the calculation of constants. In section 2.3 there, the constants were
numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`factor`</b>: Float. A multiplicative factor by which the values will be scaled.
@ -1096,20 +1062,20 @@ numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
[`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
for behavior.
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
An initializer that generates tensors with unit variance.
- - -
### `tf.zeros_initializer(shape, dtype=tf.float32)` <a class="md-anchor" id="zeros_initializer"></a>
### `tf.zeros_initializer(shape, dtype=tf.float32)` {#zeros_initializer}
An adaptor for zeros() to match the Initializer spec.
## Sparse Variable Updates <a class="md-anchor" id="AUTOGENERATED-sparse-variable-updates"></a>
## Sparse Variable Updates
The sparse update ops modify a subset of the entries in a dense `Variable`,
either overwriting the entries or adding / subtracting a delta. These are
@ -1125,7 +1091,7 @@ automatically by the optimizers in most cases.
- - -
### `tf.scatter_update(ref, indices, updates, use_locking=None, name=None)` <a class="md-anchor" id="scatter_update"></a>
### `tf.scatter_update(ref, indices, updates, use_locking=None, name=None)` {#scatter_update}
Applies sparse updates to a variable reference.
@ -1152,7 +1118,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
<img style="width:100%" src="../images/ScatterUpdate.png" alt>
</div>
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`ref`</b>: A mutable `Tensor`. Should be from a `Variable` node.
@ -1165,7 +1131,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
otherwise the behavior is undefined, but may exhibit less contention.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Same as `ref`. Returned as a convenience for operations that want
to use the updated values after the update is done.
@ -1173,7 +1139,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
- - -
### `tf.scatter_add(ref, indices, updates, use_locking=None, name=None)` <a class="md-anchor" id="scatter_add"></a>
### `tf.scatter_add(ref, indices, updates, use_locking=None, name=None)` {#scatter_add}
Adds sparse updates to a variable reference.
@ -1200,7 +1166,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
<img style="width:100%" src="../images/ScatterAdd.png" alt>
</div>
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`ref`</b>: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `qint8`, `quint8`, `qint32`.
@ -1214,7 +1180,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
otherwise the behavior is undefined, but may exhibit less contention.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Same as `ref`. Returned as a convenience for operations that want
to use the updated values after the update is done.
@ -1222,7 +1188,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
- - -
### `tf.scatter_sub(ref, indices, updates, use_locking=None, name=None)` <a class="md-anchor" id="scatter_sub"></a>
### `tf.scatter_sub(ref, indices, updates, use_locking=None, name=None)` {#scatter_sub}
Subtracts sparse updates to a variable reference.
@ -1247,7 +1213,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
<img style="width:100%" src="../images/ScatterSub.png" alt>
</div>
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* <b>`ref`</b>: A mutable `Tensor`. Must be one of the following types: `float32`, `float64`, `int64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `qint8`, `quint8`, `qint32`.
@ -1261,7 +1227,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
otherwise the behavior is undefined, but may exhibit less contention.
* <b>`name`</b>: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
Same as `ref`. Returned as a convenience for operations that want
to use the updated values after the update is done.
@ -1269,7 +1235,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]`.
- - -
### `tf.sparse_mask(a, mask_indices, name=None)` <a class="md-anchor" id="sparse_mask"></a>
### `tf.sparse_mask(a, mask_indices, name=None)` {#sparse_mask}
Masks elements of `IndexedSlices`.
@ -1298,20 +1264,20 @@ tf.shape(b.values) => [2, 10]
```
##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
##### Args:
* `a`: An `IndexedSlices` instance.
* `mask_indices`: Indices of elements to mask.
* `name`: A name for the operation (optional).
##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
##### Returns:
The masked `IndexedSlices` instance.
- - -
### `class tf.IndexedSlices` <a class="md-anchor" id="IndexedSlices"></a>
### `class tf.IndexedSlices` {#IndexedSlices}
A sparse representation of a set of tensor slices at given indices.
@ -1341,7 +1307,7 @@ which uses multi-dimensional indices and scalar values.
- - -
#### `tf.IndexedSlices.__init__(values, indices, dense_shape=None)` <a class="md-anchor" id="IndexedSlices.__init__"></a>
#### `tf.IndexedSlices.__init__(values, indices, dense_shape=None)` {#IndexedSlices.__init__}
Creates an `IndexedSlices`.
@ -1349,44 +1315,44 @@ Creates an `IndexedSlices`.
- - -
#### `tf.IndexedSlices.values` <a class="md-anchor" id="IndexedSlices.values"></a>
#### `tf.IndexedSlices.values` {#IndexedSlices.values}
A `Tensor` containing the values of the slices.
- - -
#### `tf.IndexedSlices.indices` <a class="md-anchor" id="IndexedSlices.indices"></a>
#### `tf.IndexedSlices.indices` {#IndexedSlices.indices}
A 1-D `Tensor` containing the indices of the slices.
- - -
#### `tf.IndexedSlices.dense_shape` <a class="md-anchor" id="IndexedSlices.dense_shape"></a>
#### `tf.IndexedSlices.dense_shape` {#IndexedSlices.dense_shape}
A 1-D `Tensor` containing the shape of the corresponding dense tensor.
- - -
#### `tf.IndexedSlices.name` <a class="md-anchor" id="IndexedSlices.name"></a>
#### `tf.IndexedSlices.name` {#IndexedSlices.name}
The name of this `IndexedSlices`.
- - -
#### `tf.IndexedSlices.dtype` <a class="md-anchor" id="IndexedSlices.dtype"></a>
#### `tf.IndexedSlices.dtype` {#IndexedSlices.dtype}
The `DType` of elements in this tensor.
- - -
#### `tf.IndexedSlices.device` <a class="md-anchor" id="IndexedSlices.device"></a>
#### `tf.IndexedSlices.device` {#IndexedSlices.device}
The name of the device on which `values` will be produced, or `None`.
- - -
#### `tf.IndexedSlices.op` <a class="md-anchor" id="IndexedSlices.op"></a>
#### `tf.IndexedSlices.op` {#IndexedSlices.op}
The `Operation` that produces `values` as an output.

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
# Basic Usage <a class="md-anchor" id="AUTOGENERATED-basic-usage"></a>
# Basic Usage
To use TensorFlow you need to understand how TensorFlow:
@ -8,7 +8,7 @@ To use TensorFlow you need to understand how TensorFlow:
* Maintains state with `Variables`.
* Uses feeds and fetches to get data into and out of arbitrary operations.
## Overview <a class="md-anchor" id="AUTOGENERATED-overview"></a>
## Overview
TensorFlow is a programming system in which you represent computations as
graphs. Nodes in the graph are called *ops* (short for operations). An op
@ -24,7 +24,7 @@ methods return tensors produced by ops as [numpy](http://www.numpy.org)
`ndarray` objects in Python, and as `tensorflow::Tensor` instances in C and
C++.
## The computation graph <a class="md-anchor" id="AUTOGENERATED-the-computation-graph"></a>
## The computation graph
TensorFlow programs are usually structured into a construction phase, that
assembles a graph, and an execution phase that uses a session to execute ops in
@ -40,7 +40,7 @@ of helper functions not available in the C and C++ libraries.
The session libraries have equivalent functionalities for the three languages.
### Building the graph <a class="md-anchor" id="AUTOGENERATED-building-the-graph"></a>
### Building the graph
To build a graph start with ops that do not need any input (source ops), such as
`Constant`, and pass their output to other ops that do computation.
@ -77,7 +77,7 @@ The default graph now has three nodes: two `constant()` ops and one `matmul()`
op. To actually multiply the matrices, and get the result of the multiplication,
you must launch the graph in a session.
### Launching the graph in a session <a class="md-anchor" id="AUTOGENERATED-launching-the-graph-in-a-session"></a>
### Launching the graph in a session
Launching follows construction. To launch a graph, create a `Session` object.
Without arguments the session constructor launches the default graph.
@ -146,7 +146,7 @@ Devices are specified with strings. The currently supported devices are:
See [Using GPUs](../how_tos/using_gpu/index.md) for more information about GPUs
and TensorFlow.
## Interactive Usage <a class="md-anchor" id="AUTOGENERATED-interactive-usage"></a>
## Interactive Usage
The Python examples in the documentation launch the graph with a
[`Session`](../api_docs/python/client.md#Session) and use the
@ -171,13 +171,13 @@ a = tf.constant([3.0, 3.0])
# Initialize 'x' using the run() method of its initializer op.
x.initializer.run()
# Add an op to subtact 'a' from 'x'. Run it and print the result
# Add an op to subtract 'a' from 'x'. Run it and print the result
sub = tf.sub(x, a)
print sub.eval()
# ==> [-2. -1.]
```
## Tensors <a class="md-anchor" id="AUTOGENERATED-tensors"></a>
## Tensors
TensorFlow programs use a tensor data structure to represent all data -- only
tensors are passed between operations in the computation graph. You can think
@ -186,7 +186,7 @@ static type a rank, and a shape. To learn more about how TensorFlow handles
these concepts, see the [Rank, Shape, and Type](../resources/dims_types.md)
reference.
## Variables <a class="md-anchor" id="AUTOGENERATED-variables"></a>
## Variables
Variables maintain state across executions of the graph. The following example
shows a variable serving as a simple counter. See
@ -235,7 +235,7 @@ Variables. For example, you would store the weights for a neural network as a
tensor in a Variable. During training you update this tensor by running a
training graph repeatedly.
## Fetches <a class="md-anchor" id="AUTOGENERATED-fetches"></a>
## Fetches
To fetch the outputs of operations, execute the graph with a `run()` call on
the `Session` object and pass in the tensors to retrieve. In the previous
@ -260,7 +260,7 @@ with tf.Session() as sess:
All the ops needed to produce the values of the requested tensors are run once
(not once per requested tensor).
## Feeds <a class="md-anchor" id="AUTOGENERATED-feeds"></a>
## Feeds
The examples above introduce tensors into the computation graph by storing them
in `Constants` and `Variables`. TensorFlow also provides a feed mechanism for

View File

@ -1,4 +1,4 @@
# Introduction <a class="md-anchor" id="AUTOGENERATED-introduction"></a>
# Introduction
Let's get you up and running with TensorFlow!
@ -43,6 +43,10 @@ for step in xrange(0, 201):
# Learns best fit is W: [[0.100 0.200]], b: [0.300]
```
The first part of this code builds the data flow graph. TensorFlow does not
actually run any computation until the session is created and the `run`
function is called.
To whet your appetite further, we suggest you check out what a classical
machine learning problem looks like in TensorFlow. In the land of neural
networks the most "classic" classical problem is the MNIST handwritten digit
@ -67,7 +71,7 @@ these and charge ahead. Don't worry, you'll still get to see MNIST -- we'll
also use MNIST as an example in our technical tutorial where we elaborate on
TensorFlow features.
## Recommended Next Steps: <a class="md-anchor" id="AUTOGENERATED-recommended-next-steps-"></a>
## Recommended Next Steps:
* [Download and Setup](../get_started/os_setup.md)
* [Basic Usage](../get_started/basic_usage.md)
* [TensorFlow Mechanics 101](../tutorials/mnist/tf/index.md)

View File

@ -1,8 +1,8 @@
# Download and Setup <a class="md-anchor" id="AUTOGENERATED-download-and-setup"></a>
# Download and Setup
You can install TensorFlow using our provided binary packages or from source.
## Binary Installation <a class="md-anchor" id="AUTOGENERATED-binary-installation"></a>
## Binary Installation
The TensorFlow Python API currently requires Python 2.7: we are
[working](https://github.com/tensorflow/tensorflow/issues/1) on adding support
@ -16,7 +16,7 @@ If you encounter installation errors, see
installation, please consider using our virtualenv-based instructions
[here](#virtualenv_install).
### Ubuntu/Linux 64-bit <a class="md-anchor" id="AUTOGENERATED-ubuntu-linux-64-bit"></a>
### Ubuntu/Linux 64-bit
```bash
# For CPU-only version
@ -26,7 +26,7 @@ $ pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5
$ pip install https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
```
### Mac OS X <a class="md-anchor" id="AUTOGENERATED-mac-os-x"></a>
### Mac OS X
On OS X, we recommend installing [homebrew](http://brew.sh) and `brew install
python` before proceeding, or installing TensorFlow within [virtualenv](#virtualenv_install).
@ -36,7 +36,7 @@ python` before proceeding, or installing TensorFlow within [virtualenv](#virtual
$ pip install https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
```
## Docker-based installation <a class="md-anchor" id="AUTOGENERATED-docker-based-installation"></a>
## Docker-based installation
We also support running TensorFlow via [Docker](http://docker.com/), which lets
you avoid worrying about setting up dependencies.
@ -51,7 +51,7 @@ $ docker run -it b.gcr.io/tensorflow/tensorflow
This will start a container with TensorFlow and all its dependencies already
installed.
### Additional images <a class="md-anchor" id="AUTOGENERATED-additional-images"></a>
### Additional images
The default Docker image above contains just a minimal set of libraries for
getting up and running with TensorFlow. We also have the following container,
@ -62,7 +62,7 @@ which you can use in the `docker run` command above:
makes it easy to experiment directly with the source, without needing to
install any of the dependencies described above.
## VirtualEnv-based installation <a class="md-anchor" id="virtualenv_install"></a>
## VirtualEnv-based installation {#virtualenv_install}
We recommend using [virtualenv](https://pypi.python.org/pypi/virtualenv) to
create an isolated container and install TensorFlow in that container -- it is
@ -121,9 +121,9 @@ then run an example TensorFlow program like:
$ # Your prompt should change back
```
## Try your first TensorFlow program <a class="md-anchor" id="AUTOGENERATED-try-your-first-tensorflow-program"></a>
## Try your first TensorFlow program
### (Optional) Enable GPU Support <a class="md-anchor" id="AUTOGENERATED--optional--enable-gpu-support"></a>
### (Optional) Enable GPU Support
If you installed the GPU-enabled TensorFlow pip binary, you must have the
correct versions of the CUDA SDK and CUDNN installed on your
@ -138,7 +138,7 @@ export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64"
export CUDA_HOME=/usr/local/cuda
```
### Run TensorFlow <a class="md-anchor" id="AUTOGENERATED-run-tensorflow"></a>
### Run TensorFlow
Open a python terminal:
@ -158,9 +158,9 @@ Hello, TensorFlow!
```
## Installing from sources <a class="md-anchor" id="source"></a>
## Installing from sources {#source}
### Clone the TensorFlow repository <a class="md-anchor" id="AUTOGENERATED-clone-the-tensorflow-repository"></a>
### Clone the TensorFlow repository
```bash
$ git clone --recurse-submodules https://github.com/tensorflow/tensorflow
@ -169,9 +169,9 @@ $ git clone --recurse-submodules https://github.com/tensorflow/tensorflow
`--recurse-submodules` is required to fetch the protobuf library that TensorFlow
depends on.
### Installation for Linux <a class="md-anchor" id="AUTOGENERATED-installation-for-linux"></a>
### Installation for Linux
#### Install Bazel <a class="md-anchor" id="AUTOGENERATED-install-bazel"></a>
#### Install Bazel
Follow instructions [here](http://bazel.io/docs/install.html) to install the
@ -190,13 +190,13 @@ downloaded the installer.
Finally, follow the instructions in that script to place bazel into your binary
path.
#### Install other dependencies <a class="md-anchor" id="AUTOGENERATED-install-other-dependencies"></a>
#### Install other dependencies
```bash
$ sudo apt-get install python-numpy swig python-dev
```
#### Optional: Install CUDA (GPUs on Linux) <a class="md-anchor" id="install_cuda"></a>
#### Optional: Install CUDA (GPUs on Linux) {#install_cuda}
In order to build or run TensorFlow with GPU support, both Cuda Toolkit 7.0 and
CUDNN 6.5 V2 from NVIDIA need to be installed.
@ -208,13 +208,13 @@ TensorFlow GPU support requires having a GPU card with NVidia Compute Capability
* NVidia K20
* NVidia K40
##### Download and install Cuda Toolkit 7.0 <a class="md-anchor" id="AUTOGENERATED-download-and-install-cuda-toolkit-7.0"></a>
##### Download and install Cuda Toolkit 7.0
https://developer.nvidia.com/cuda-toolkit-70
Install the toolkit into e.g. `/usr/local/cuda`
##### Download and install CUDNN Toolkit 6.5 <a class="md-anchor" id="AUTOGENERATED-download-and-install-cudnn-toolkit-6.5"></a>
##### Download and install CUDNN Toolkit 6.5
https://developer.nvidia.com/rdp/cudnn-archive
@ -227,7 +227,7 @@ sudo cp cudnn-6.5-linux-x64-v2/cudnn.h /usr/local/cuda/include
sudo cp cudnn-6.5-linux-x64-v2/libcudnn* /usr/local/cuda/lib64
```
##### Configure TensorFlow's canonical view of Cuda libraries <a class="md-anchor" id="AUTOGENERATED-configure-tensorflow-s-canonical-view-of-cuda-libraries"></a>
##### Configure TensorFlow's canonical view of Cuda libraries
From the root of your source tree, run:
``` bash
@ -252,7 +252,7 @@ This creates a canonical set of symbolic links to the Cuda libraries on your sys
Every time you change the Cuda library paths you need to run this step again before
you invoke the bazel build command.
##### Build your target with GPU support. <a class="md-anchor" id="AUTOGENERATED-build-your-target-with-gpu-support."></a>
##### Build your target with GPU support.
From the root of your source tree, run:
```bash
@ -268,7 +268,7 @@ $ bazel-bin/tensorflow/cc/tutorials_example_trainer --use_gpu
Note that "--config=cuda" is needed to enable the GPU support.
##### Enabling Cuda 3.0. <a class="md-anchor" id="AUTOGENERATED-enabling-cuda-3.0."></a>
##### Enabling Cuda 3.0.
TensorFlow officially supports Cuda devices with 3.5 and 5.2 compute
capabilities. In order to enable earlier Cuda devices such as Grid K520, you
need to target Cuda 3.0. This can be done through TensorFlow unofficial
@ -296,7 +296,7 @@ Setting up Cuda nvvm
Configuration finished
```
##### Known issues <a class="md-anchor" id="AUTOGENERATED-known-issues"></a>
##### Known issues
* Although it is possible to build both Cuda and non-Cuda configs under the same
source tree, we recommend to run "bazel clean" when switching between these two
@ -307,30 +307,30 @@ will fail with a clear error message. In the future, we might consider making
this more conveninent by including the configure step in our build process,
given necessary bazel new feature support.
### Installation for Mac OS X <a class="md-anchor" id="AUTOGENERATED-installation-for-mac-os-x"></a>
### Installation for Mac OS X
Mac needs the same set of dependencies as Linux, however their installing those
Mac needs the same set of dependencies as Linux, however installing those
dependencies is different. Here is a set of useful links to help with installing
the dependencies on Mac OS X :
#### Bazel <a class="md-anchor" id="AUTOGENERATED-bazel"></a>
#### Bazel
Look for installation instructions for Mac OS X on
[this](http://bazel.io/docs/install.html) page.
#### SWIG <a class="md-anchor" id="AUTOGENERATED-swig"></a>
#### SWIG
[Mac OS X installation](http://www.swig.org/Doc3.0/Preface.html#Preface_osx_installation).
Notes : You need to install
[PCRE](ftp://ftp.csx.cam.ac.uk/pub/software/programming/pcre/) and *NOT* PCRE2.
#### Numpy <a class="md-anchor" id="AUTOGENERATED-numpy"></a>
#### Numpy
Follow installation instructions [here](http://docs.scipy.org/doc/numpy/user/install.html).
### Create the pip package and install <a class="md-anchor" id="create-pip"></a>
### Create the pip package and install {#create-pip}
```bash
$ bazel build -c opt //tensorflow/tools/pip_package:build_pip_package
@ -344,7 +344,7 @@ $ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
$ pip install /tmp/tensorflow_pkg/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
```
## Train your first TensorFlow neural net model <a class="md-anchor" id="AUTOGENERATED-train-your-first-tensorflow-neural-net-model"></a>
## Train your first TensorFlow neural net model
Starting from the root of your source tree, run:
@ -372,9 +372,9 @@ Validation error: 7.0%
...
```
## Common Problems <a class="md-anchor" id="common_install_problems"></a>
## Common Problems {#common_install_problems}
### GPU-related issues <a class="md-anchor" id="AUTOGENERATED-gpu-related-issues"></a>
### GPU-related issues
If you encounter the following when trying to run a TensorFlow program:
@ -384,9 +384,9 @@ ImportError: libcudart.so.7.0: cannot open shared object file: No such file or d
Make sure you followed the the GPU installation [instructions](#install_cuda).
### Pip installation issues <a class="md-anchor" id="AUTOGENERATED-pip-installation-issues"></a>
### Pip installation issues
#### Can't find setup.py <a class="md-anchor" id="AUTOGENERATED-can-t-find-setup.py"></a>
#### Can't find setup.py
If, during `pip install`, you encounter an error like:
@ -403,7 +403,7 @@ pip install --upgrade pip
This may require `sudo`, depending on how `pip` is installed.
#### SSLError: SSL_VERIFY_FAILED <a class="md-anchor" id="AUTOGENERATED-sslerror--ssl_verify_failed"></a>
#### SSLError: SSL_VERIFY_FAILED
If, during pip install from a URL, you encounter an error like:
@ -414,7 +414,7 @@ SSLError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed
Solution: Download the wheel manually via curl or wget, and pip install locally.
### On Linux <a class="md-anchor" id="AUTOGENERATED-on-linux"></a>
### On Linux
If you encounter:
@ -427,7 +427,7 @@ SyntaxError: invalid syntax
Solution: make sure you are using Python 2.7.
### On MacOSX <a class="md-anchor" id="AUTOGENERATED-on-macosx"></a>
### On MacOSX
If you encounter:

View File

@ -1,4 +1,4 @@
# Adding a New Op <a class="md-anchor" id="AUTOGENERATED-adding-a-new-op"></a>
# Adding a New Op
PREREQUISITES:
@ -24,30 +24,9 @@ to:
for the Op. This allows shape inference to work with your Op.
* Test the Op, typically in Python.
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Adding a New Op](#AUTOGENERATED-adding-a-new-op)
* [Define the Op's interface](#define_interface)
* [Implement the kernel for the Op](#AUTOGENERATED-implement-the-kernel-for-the-op)
* [Generate the client wrapper](#AUTOGENERATED-generate-the-client-wrapper)
* [The Python Op wrapper](#AUTOGENERATED-the-python-op-wrapper)
* [The C++ Op wrapper](#AUTOGENERATED-the-c---op-wrapper)
* [Verify it works](#AUTOGENERATED-verify-it-works)
* [Validation](#Validation)
* [Op registration](#AUTOGENERATED-op-registration)
* [Attrs](#Attrs)
* [Attr types](#AUTOGENERATED-attr-types)
* [Polymorphism](#Polymorphism)
* [Inputs and Outputs](#AUTOGENERATED-inputs-and-outputs)
* [Backwards compatibility](#AUTOGENERATED-backwards-compatibility)
* [GPU Support](#mult-archs)
* [Implement the gradient in Python](#AUTOGENERATED-implement-the-gradient-in-python)
* [Implement a shape function in Python](#AUTOGENERATED-implement-a-shape-function-in-python)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Define the Op's interface <a class="md-anchor" id="define_interface"></a>
## Define the Op's interface {#define_interface}
You define the interface of an Op by registering it with the TensorFlow system.
In the registration, you specify the name of your Op, its inputs (types and
@ -73,7 +52,7 @@ outputs a tensor `zeroed` of 32-bit integers.
> A note on naming: The name of the Op should be unique and CamelCase. Names
> starting with an underscore (`_`) are reserved for internal use.
## Implement the kernel for the Op <a class="md-anchor" id="AUTOGENERATED-implement-the-kernel-for-the-op"></a>
## Implement the kernel for the Op
After you define the interface, provide one or more implementations of the Op.
To create one of these kernels, create a class that extends `OpKernel` and
@ -131,8 +110,8 @@ Once you
[build and reinstall TensorFlow](../../get_started/os_setup.md#create-pip), the
Tensorflow system can reference and use the Op when requested.
## Generate the client wrapper <a class="md-anchor" id="AUTOGENERATED-generate-the-client-wrapper"></a>
### The Python Op wrapper <a class="md-anchor" id="AUTOGENERATED-the-python-op-wrapper"></a>
## Generate the client wrapper
### The Python Op wrapper
Python op wrappers are created automatically in
`bazel-genfiles/tensorflow/python/ops/gen_user_ops.py` for all ops placed in the
@ -176,7 +155,7 @@ def my_fact():
return gen_user_ops._fact()
```
### The C++ Op wrapper <a class="md-anchor" id="AUTOGENERATED-the-c---op-wrapper"></a>
### The C++ Op wrapper
C++ op wrappers are created automatically for all ops placed in the
[`tensorflow/core/user_ops`][user_ops] directory, when you build Tensorflow. For
@ -191,7 +170,7 @@ statement
#include "tensorflow/cc/ops/user_ops.h"
```
## Verify it works <a class="md-anchor" id="AUTOGENERATED-verify-it-works"></a>
## Verify it works
A good way to verify that you've successfully implemented your Op is to write a
test for it. Create the file
@ -214,7 +193,7 @@ Then run your test:
$ bazel test tensorflow/python:zero_out_op_test
```
## Validation <a class="md-anchor" id="Validation"></a>
## Validation {#Validation}
The example above assumed that the Op applied to a tensor of any shape. What
if it only applied to vectors? That means adding a check to the above OpKernel
@ -253,9 +232,9 @@ function is an error, and if so return it, use
[`OP_REQUIRES_OK`][validation-macros]. Both of these macros return from the
function on error.
## Op registration <a class="md-anchor" id="AUTOGENERATED-op-registration"></a>
## Op registration
### Attrs <a class="md-anchor" id="Attrs"></a>
### Attrs {#Attrs}
Ops can have attrs, whose values are set when the Op is added to a graph. These
are used to configure the Op, and their values can be accessed both within the
@ -339,7 +318,7 @@ which can then be used in the `Compute` method:
> .Output("zeroed: int32");
> </pre></code>
### Attr types <a class="md-anchor" id="AUTOGENERATED-attr-types"></a>
### Attr types
The following types are supported in an attr:
@ -355,7 +334,7 @@ The following types are supported in an attr:
See also: [`op_def_builder.cc:FinalizeAttr`][FinalizeAttr] for a definitive list.
#### Default values & constraints <a class="md-anchor" id="AUTOGENERATED-default-values---constraints"></a>
#### Default values & constraints
Attrs may have default values, and some types of attrs can have constraints. To
define an attr with constraints, you can use the following `<attr-type-expr>`s:
@ -456,8 +435,8 @@ REGISTER_OP("AttrDefaultExampleForAllTypes")
Note in particular that the values of type `type` use [the `DT_*` names
for the types](../../resources/dims_types.md#data-types).
### Polymorphism <a class="md-anchor" id="Polymorphism"></a>
#### Type Polymorphism <a class="md-anchor" id="type-polymorphism"></a>
### Polymorphism {#Polymorphism}
#### Type Polymorphism {#type-polymorphism}
For ops that can take different types as input or produce different output
types, you can specify [an attr](#attrs) in
@ -685,7 +664,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL
```
#### List Inputs and Outputs <a class="md-anchor" id="list-input-output"></a>
#### List Inputs and Outputs {#list-input-output}
In addition to being able to accept or produce different types, ops can consume
or produce a variable number of tensors.
@ -760,7 +739,7 @@ REGISTER_OP("MinimumLengthPolymorphicListExample")
.Output("out: T");
```
### Inputs and Outputs <a class="md-anchor" id="AUTOGENERATED-inputs-and-outputs"></a>
### Inputs and Outputs
To summarize the above, an Op registration can have multiple inputs and outputs:
@ -861,7 +840,7 @@ expressions:
For more details, see
[`tensorflow/core/framework/op_def_builder.h`][op_def_builder].
### Backwards compatibility <a class="md-anchor" id="AUTOGENERATED-backwards-compatibility"></a>
### Backwards compatibility
In general, changes to specifications must be backwards-compatible: changing the
specification of an Op must not break prior serialized GraphDefs constructed
@ -907,7 +886,7 @@ The full list of safe and unsafe changes can be found in
If you cannot make your change to an operation backwards compatible, then create
a new operation with a new name with the new semantics.
## GPU Support <a class="md-anchor" id="mult-archs"></a>
## GPU Support {#mult-archs}
You can implement different OpKernels and register one for CPU and another for
GPU, just like you can [register kernels for different types](#Polymorphism).
@ -935,7 +914,7 @@ kept on the CPU, add a `HostMemory()` call to the kernel registration, e.g.:
PadOp<GPUDevice, T>)
```
## Implement the gradient in Python <a class="md-anchor" id="AUTOGENERATED-implement-the-gradient-in-python"></a>
## Implement the gradient in Python
Given a graph of ops, TensorFlow uses automatic differentiation
(backpropagation) to add new ops representing gradients with respect to the
@ -1012,7 +991,7 @@ Note that at the time the gradient function is called, only the data flow graph
of ops is available, not the tensor data itself. Thus, all computation must be
performed using other tensorflow ops, to be run at graph execution time.
## Implement a shape function in Python <a class="md-anchor" id="AUTOGENERATED-implement-a-shape-function-in-python"></a>
## Implement a shape function in Python
The TensorFlow Python API has a feature called "shape inference" that provides
information about the shapes of tensors without having to execute the

View File

@ -1,4 +1,4 @@
# TensorBoard: Graph Visualization <a class="md-anchor" id="AUTOGENERATED-tensorboard--graph-visualization"></a>
# TensorBoard: Graph Visualization
TensorFlow computation graphs are powerful but complicated. The graph visualization can help you understand and debug them. Here's an example of the visualization at work.
@ -7,7 +7,7 @@ TensorFlow computation graphs are powerful but complicated. The graph visualizat
To see your own graph, run TensorBoard pointing it to the log directory of the job, click on the graph tab on the top pane and select the appropriate run using the menu at the upper left corner. For in depth information on how to run TensorBoard and make sure you are logging all the necessary information, see [Summaries and TensorBoard](../../how_tos/summaries_and_tensorboard/index.md).
## Name scoping and nodes <a class="md-anchor" id="AUTOGENERATED-name-scoping-and-nodes"></a>
## Name scoping and nodes
Typical TensorFlow graphs can have many thousands of nodes--far too many to see
easily all at once, or even to lay out using standard graph tools. To simplify,
@ -142,7 +142,7 @@ Symbol | Meaning
![Control dependency edge](./control_edge.png "Control dependency edge") | Edge showing the control dependency between operations.
![Reference edge](./reference_edge.png "Reference edge") | A reference edge showing that the outgoing operation node can mutate the incoming tensor.
## Interaction <a class="md-anchor" id="AUTOGENERATED-interaction"></a>
## Interaction
Navigate the graph by panning and zooming. Click and drag to pan, and use a
scroll gesture to zoom. Double-click on a node, or click on its `+` button, to

View File

@ -1,4 +1,4 @@
# Custom Data Readers <a class="md-anchor" id="AUTOGENERATED-custom-data-readers"></a>
# Custom Data Readers
PREREQUISITES:
@ -20,16 +20,9 @@ For example, to read a
followed by
[an Op that parses CSV data from a line of text](../../api_docs/python/io_ops.md#decode_csv).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Custom Data Readers](#AUTOGENERATED-custom-data-readers)
* [Writing a Reader for a file format](#AUTOGENERATED-writing-a-reader-for-a-file-format)
* [Writing an Op for a record format](#AUTOGENERATED-writing-an-op-for-a-record-format)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Writing a Reader for a file format <a class="md-anchor" id="AUTOGENERATED-writing-a-reader-for-a-file-format"></a>
## Writing a Reader for a file format
A `Reader` is something that reads records from a file. There are some examples
of Reader Ops already built into TensorFlow:
@ -196,7 +189,7 @@ ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)
You can see some examples in
[`tensorflow/python/ops/io_ops.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/io_ops.py).
## Writing an Op for a record format <a class="md-anchor" id="AUTOGENERATED-writing-an-op-for-a-record-format"></a>
## Writing an Op for a record format
Generally this is an ordinary op that takes a scalar string record as input, and
so follow [the instructions to add an Op](../../how_tos/adding_an_op/index.md). You may

View File

@ -1,4 +1,4 @@
# Reading data <a class="md-anchor" id="AUTOGENERATED-reading-data"></a>
# Reading data
There are three main methods of getting data into a TensorFlow program:
@ -8,25 +8,9 @@ There are three main methods of getting data into a TensorFlow program:
* Preloaded data: a constant or variable in the TensorFlow graph holds
all the data (for small data sets).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Reading data](#AUTOGENERATED-reading-data)
* [Feeding](#Feeding)
* [Reading from files](#AUTOGENERATED-reading-from-files)
* [Filenames, shuffling, and epoch limits](#AUTOGENERATED-filenames--shuffling--and-epoch-limits)
* [File formats](#AUTOGENERATED-file-formats)
* [Preprocessing](#AUTOGENERATED-preprocessing)
* [Batching](#AUTOGENERATED-batching)
* [Creating threads to prefetch using `QueueRunner` objects](#QueueRunner)
* [Filtering records or producing multiple examples per record](#AUTOGENERATED-filtering-records-or-producing-multiple-examples-per-record)
* [Sparse input data](#AUTOGENERATED-sparse-input-data)
* [Preloaded data](#AUTOGENERATED-preloaded-data)
* [Multiple input pipelines](#AUTOGENERATED-multiple-input-pipelines)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Feeding <a class="md-anchor" id="Feeding"></a>
## Feeding {#Feeding}
TensorFlow's feed mechanism lets you inject data into any Tensor in a
computation graph. A python computation can thus feed data directly into the
@ -54,7 +38,7 @@ in
[`tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py),
and is described in the [MNIST tutorial](../../tutorials/mnist/tf/index.md).
## Reading from files <a class="md-anchor" id="AUTOGENERATED-reading-from-files"></a>
## Reading from files
A typical pipeline for reading records from files has the following stages:
@ -67,7 +51,7 @@ A typical pipeline for reading records from files has the following stages:
7. *Optional* preprocessing
8. Example queue
### Filenames, shuffling, and epoch limits <a class="md-anchor" id="AUTOGENERATED-filenames--shuffling--and-epoch-limits"></a>
### Filenames, shuffling, and epoch limits
For the list of filenames, use either a constant string Tensor (like
`["file0", "file1"]` or `[("file%d" % i) for i in range(2)]`) or the
@ -89,7 +73,7 @@ The queue runner works in a thread separate from the reader that pulls
filenames from the queue, so the shuffling and enqueuing process does not
block the reader.
### File formats <a class="md-anchor" id="AUTOGENERATED-file-formats"></a>
### File formats
Select the reader that matches your input file format and pass the filename
queue to the reader's read method. The read method outputs a key identifying
@ -97,7 +81,7 @@ the file and record (useful for debugging if you have some weird records), and
a scalar string value. Use one (or more) of the decoder and conversion ops to
decode this string into the tensors that make up an example.
#### CSV files <a class="md-anchor" id="AUTOGENERATED-csv-files"></a>
#### CSV files
To read text files in [comma-separated value (CSV)
format](https://tools.ietf.org/html/rfc4180), use a
@ -139,7 +123,7 @@ You must call `tf.train.start_queue_runners` to populate the queue before
you call `run` or `eval` to execute the `read`. Otherwise `read` will
block while it waits for filenames from the queue.
#### Fixed length records <a class="md-anchor" id="AUTOGENERATED-fixed-length-records"></a>
#### Fixed length records
To read binary files in which each record is a fixed number of bytes, use
[`tf.FixedLengthRecordReader`](../../api_docs/python/io_ops.md#FixedLengthRecordReader)
@ -155,7 +139,7 @@ needed. For CIFAR-10, you can see how to do the reading and decoding in
and described in
[this tutorial](../../tutorials/deep_cnn/index.md#prepare-the-data).
#### Standard TensorFlow format <a class="md-anchor" id="AUTOGENERATED-standard-tensorflow-format"></a>
#### Standard TensorFlow format
Another approach is to convert whatever data you have into a supported format.
This approach makes it easier to mix and match data sets and network
@ -181,7 +165,7 @@ found in
[`tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py),
which you can compare with the `fully_connected_feed` version.
### Preprocessing <a class="md-anchor" id="AUTOGENERATED-preprocessing"></a>
### Preprocessing
You can then do any preprocessing of these examples you want. This would be any
processing that doesn't depend on trainable parameters. Examples include
@ -190,7 +174,7 @@ etc. See
[`tensorflow/models/image/cifar10/cifar10.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/cifar10.py)
for an example.
### Batching <a class="md-anchor" id="AUTOGENERATED-batching"></a>
### Batching
At the end of the pipeline we use another queue to batch together examples for
training, evaluation, or inference. For this we use a queue that randomizes the
@ -268,7 +252,7 @@ summary to the graph that indicates how full the example queue is. If you have
enough reading threads, that summary will stay above zero. You can
[view your summaries as training progresses using TensorBoard](../../how_tos/summaries_and_tensorboard/index.md).
### Creating threads to prefetch using `QueueRunner` objects <a class="md-anchor" id="QueueRunner"></a>
### Creating threads to prefetch using `QueueRunner` objects {#QueueRunner}
The short version: many of the `tf.train` functions listed above add
[`QueueRunner`](../../api_docs/python/train.md#QueueRunner) objects to your
@ -312,7 +296,7 @@ coord.join(threads)
sess.close()
```
#### Aside: What is happening here? <a class="md-anchor" id="AUTOGENERATED-aside--what-is-happening-here-"></a>
#### Aside: What is happening here?
First we create the graph. It will have a few pipeline stages that are
connected by queues. The first stage will generate filenames to read and enqueue
@ -357,7 +341,7 @@ exception).
For more about threading, queues, QueueRunners, and Coordinators
[see here](../../how_tos/threading_and_queues/index.md).
#### Aside: How clean shut-down when limiting epochs works <a class="md-anchor" id="AUTOGENERATED-aside--how-clean-shut-down-when-limiting-epochs-works"></a>
#### Aside: How clean shut-down when limiting epochs works
Imagine you have a model that has set a limit on the number of epochs to train
on. That means that the thread generating filenames will only run that many
@ -400,7 +384,7 @@ errors and exiting. Once all the training threads are done,
[`tf.train.Coordinator.join`](../../api_docs/python/train.md#Coordinator.join)
will return and you can exit cleanly.
### Filtering records or producing multiple examples per record <a class="md-anchor" id="AUTOGENERATED-filtering-records-or-producing-multiple-examples-per-record"></a>
### Filtering records or producing multiple examples per record
Instead of examples with shapes `[x, y, z]`, you will produce a batch of
examples with shape `[batch, x, y, z]`. The batch size can be 0 if you want to
@ -409,14 +393,14 @@ are producing multiple examples per record. Then simply set `enqueue_many=True`
when calling one of the batching functions (such as `shuffle_batch` or
`shuffle_batch_join`).
### Sparse input data <a class="md-anchor" id="AUTOGENERATED-sparse-input-data"></a>
### Sparse input data
SparseTensors don't play well with queues. If you use SparseTensors you have
to decode the string records using
[`tf.parse_example`](../../api_docs/python/io_ops.md#parse_example) **after**
batching (instead of using `tf.parse_single_example` before batching).
## Preloaded data <a class="md-anchor" id="AUTOGENERATED-preloaded-data"></a>
## Preloaded data
This is only used for small data sets that can be loaded entirely in memory.
There are two approaches:
@ -475,7 +459,7 @@ An MNIST example that preloads the data using constants can be found in
You can compare these with the `fully_connected_feed` and
`fully_connected_reader` versions above.
## Multiple input pipelines <a class="md-anchor" id="AUTOGENERATED-multiple-input-pipelines"></a>
## Multiple input pipelines
Commonly you will want to train on one dataset and evaluate (or "eval") on
another. One way to do this is to actually have two separate processes:

View File

@ -1,4 +1,4 @@
# TensorBoard: Visualizing Learning <a class="md-anchor" id="AUTOGENERATED-tensorboard--visualizing-learning"></a>
# TensorBoard: Visualizing Learning
The computations you'll use TensorBoard for - like training a massive
deep neural network - can be complex and confusing. To make it easier to
@ -11,7 +11,7 @@ TensorBoard is fully configured, it looks like this:
![MNIST TensorBoard](./mnist_tensorboard.png "MNIST TensorBoard")
## Serializing the data <a class="md-anchor" id="AUTOGENERATED-serializing-the-data"></a>
## Serializing the data
TensorBoard operates by reading TensorFlow events files, which contain summary
data that you can generate when running TensorFlow. Here's the general
@ -79,7 +79,7 @@ while training:
You're now all set to visualize this data using TensorBoard.
## Launching TensorBoard <a class="md-anchor" id="AUTOGENERATED-launching-tensorboard"></a>
## Launching TensorBoard
To run TensorBoard, use the command

View File

@ -1,4 +1,4 @@
# Threading and Queues <a class="md-anchor" id="AUTOGENERATED-threading-and-queues"></a>
# Threading and Queues
Queues are a powerful mechanism for asynchronous computation using TensorFlow.
@ -24,7 +24,7 @@ API, they are methods of the queue object (eg. `q.enqueue(...)`).
Now that you have a bit of a feel for queues, let's dive into the details...
## Queue Use Overview <a class="md-anchor" id="AUTOGENERATED-queue-use-overview"></a>
## Queue Use Overview
Queues, such as `FIFOQueue` and `RandomShuffleQueue`, are important TensorFlow
objects for computing tensors asynchronously in a graph.
@ -54,7 +54,7 @@ stop together and report exceptions to a program that waits for them to stop.
The `QueueRunner` class is used to create a number of threads cooperating to
enqueue tensors in the same queue.
## Coordinator <a class="md-anchor" id="AUTOGENERATED-coordinator"></a>
## Coordinator
The Coordinator class helps multiple threads stop together.
@ -96,7 +96,7 @@ Obviously, the coordinator can manage threads doing very different things.
They don't have to be all the same as in the example above. The coordinator
also has support to capture and report exceptions. See the [Coordinator class](../../api_docs/python/train.md#Coordinator) documentation for more details.
## QueueRunner <a class="md-anchor" id="AUTOGENERATED-queuerunner"></a>
## QueueRunner
The `QueueRunner` class creates a number of threads that repeatedly run an
enqueue op. These threads can use a coordinator to stop together. In
@ -145,7 +145,7 @@ coord.request_stop()
coord.join(threads)
```
## Handling Exceptions <a class="md-anchor" id="AUTOGENERATED-handling-exceptions"></a>
## Handling Exceptions
Threads started by queue runners do more than just run the enqueue ops. They
also catch and handle exceptions generated by queues, including

View File

@ -1,6 +1,6 @@
# Using GPUs <a class="md-anchor" id="AUTOGENERATED-using-gpus"></a>
# Using GPUs
## Supported devices <a class="md-anchor" id="AUTOGENERATED-supported-devices"></a>
## Supported devices
On a typical system, there are multiple computing devices. In TensorFlow, the
supported device types are `CPU` and `GPU`. They are represented as
@ -16,7 +16,7 @@ a device. For example, `matmul` has both CPU and GPU kernels. On a
system with devices `cpu:0` and `gpu:0`, `gpu:0` will be selected to run
`matmul`.
## Logging Device placement <a class="md-anchor" id="AUTOGENERATED-logging-device-placement"></a>
## Logging Device placement
To find out which devices your operations and tensors are assigned to, create
the session with `log_device_placement` configuration option set to `True`.
@ -46,7 +46,7 @@ MatMul: /job:localhost/replica:0/task:0/gpu:0
```
## Manual device placement <a class="md-anchor" id="AUTOGENERATED-manual-device-placement"></a>
## Manual device placement
If you would like a particular operation to run on a device of your
choice instead of what's automatically selected for you, you can use
@ -78,7 +78,7 @@ MatMul: /job:localhost/replica:0/task:0/gpu:0
[ 49. 64.]]
```
## Using a single GPU on a multi-GPU system <a class="md-anchor" id="AUTOGENERATED-using-a-single-gpu-on-a-multi-gpu-system"></a>
## Using a single GPU on a multi-GPU system
If you have more than one GPU in your system, the GPU with the lowest ID will be
selected by default. If you would like to run on a different GPU, you will need
@ -125,7 +125,7 @@ sess = tf.Session(config=tf.ConfigProto(
print sess.run(c)
```
## Using multiple GPUs <a class="md-anchor" id="AUTOGENERATED-using-multiple-gpus"></a>
## Using multiple GPUs
If you would like to run TensorFlow on multiple GPUs, you can construct your
model in a multi-tower fashion where each tower is assigned to a different GPU.

View File

@ -1,4 +1,4 @@
# Sharing Variables <a class="md-anchor" id="AUTOGENERATED-sharing-variables"></a>
# Sharing Variables
You can create, initialize, save and load single variables
in the way described in the [Variables HowTo](../../how_tos/variables/index.md).
@ -7,7 +7,7 @@ variables and you might want to initialize all of them in one place.
This tutorial shows how this can be done using `tf.variable_scope()` and
the `tf.get_variable()`.
## The Problem <a class="md-anchor" id="AUTOGENERATED-the-problem"></a>
## The Problem
Imagine you create a simple model for image filters, similar to our
[Convolutional Neural Networks Tutorial](../../tutorials/deep_cnn/index.md)
@ -88,7 +88,7 @@ For a lighter solution, not involving classes, TensorFlow provides
a *Variable Scope* mechanism that allows to easily share named variables
while constructing a graph.
## Variable Scope Example <a class="md-anchor" id="AUTOGENERATED-variable-scope-example"></a>
## Variable Scope Example
Variable Scope mechanism in TensorFlow consists of 2 main functions:
@ -162,9 +162,9 @@ with tf.variable_scope("image_filters") as scope:
This is a good way to share variables, lightweight and safe.
## How Does Variable Scope Work? <a class="md-anchor" id="AUTOGENERATED-how-does-variable-scope-work-"></a>
## How Does Variable Scope Work?
### Understanding `tf.get_variable()` <a class="md-anchor" id="AUTOGENERATED-understanding--tf.get_variable---"></a>
### Understanding `tf.get_variable()`
To understand variable scope it is necessary to first
fully understand how `tf.get_variable()` works.
@ -210,7 +210,7 @@ with tf.variable_scope("foo", reuse=True):
assert v1 == v
```
### Basics of `tf.variable_scope()` <a class="md-anchor" id="AUTOGENERATED-basics-of--tf.variable_scope---"></a>
### Basics of `tf.variable_scope()`
Knowing how `tf.get_variable()` works makes it easy to understand variable
scope. The primary function of variable scope is to carry a name that will
@ -268,7 +268,7 @@ with tf.variable_scope("root"):
assert tf.get_variable_scope().reuse == False
```
### Capturing variable scope <a class="md-anchor" id="AUTOGENERATED-capturing-variable-scope"></a>
### Capturing variable scope
In all examples presented above, we shared parameters only because their
names agreed, that is, because we opened a reusing variable scope with
@ -303,7 +303,7 @@ with tf.variable_scope("bar")
assert foo_scope2.name == "foo" # Not changed.
```
### Initializers in variable scope <a class="md-anchor" id="AUTOGENERATED-initializers-in-variable-scope"></a>
### Initializers in variable scope
Using `tf.get_variable()` allows to write functions that create or reuse
variables and can be transparently called from outside. But what if we wanted
@ -329,7 +329,7 @@ with tf.variable_scope("foo", initializer=tf.constant_initializer(0.4)):
assert v.eval() == 0.2 # Changed default initializer.
```
### Names of ops in `tf.variable_scope()` <a class="md-anchor" id="AUTOGENERATED-names-of-ops-in--tf.variable_scope---"></a>
### Names of ops in `tf.variable_scope()`
We discussed how `tf.variable_scope` governs the names of variables.
But how does it influence the names of other ops in the scope?
@ -359,7 +359,7 @@ When opening a variable scope using a captured object instead of a string,
we do not alter the current name scope for ops.
## Examples of Use <a class="md-anchor" id="AUTOGENERATED-examples-of-use"></a>
## Examples of Use
Here are pointers to a few files that make use of variable scope.
In particular, it is heavily used for recurrent neural networks

View File

@ -1,4 +1,4 @@
# Variables: Creation, Initialization, Saving, and Loading <a class="md-anchor" id="AUTOGENERATED-variables--creation--initialization--saving--and-loading"></a>
# Variables: Creation, Initialization, Saving, and Loading
When you train a model, you use [variables](../../api_docs/python/state_ops.md)
to hold and update parameters. Variables are in-memory buffers containing
@ -13,7 +13,7 @@ their reference manual for a complete description of their API:
* The [`tf.train.Saver`](../../api_docs/python/state_ops.md#Saver) class.
## Creation <a class="md-anchor" id="AUTOGENERATED-creation"></a>
## Creation
When you create a [Variable](../../api_docs/python/state_ops.md) you pass a
`Tensor` as its initial value to the `Variable()` constructor. TensorFlow
@ -43,7 +43,7 @@ Calling `tf.Variable()` adds several ops to the graph:
The value returned by `tf.Variable()` value is an instance of the Python class
`tf.Variable`.
## Initialization <a class="md-anchor" id="AUTOGENERATED-initialization"></a>
## Initialization
Variable initializers must be run explicitly before other ops in your model can
be run. The easiest way to do that is to add an op that runs all the variable
@ -74,7 +74,7 @@ with tf.Session() as sess:
...
```
### Initialization from another Variable <a class="md-anchor" id="AUTOGENERATED-initialization-from-another-variable"></a>
### Initialization from another Variable
You sometimes need to initialize a variable from the initial value of another
variable. As the op added by `tf.initialize_all_variables()` initializes all
@ -96,7 +96,7 @@ w2 = tf.Variable(weights.initialized_value(), name="w2")
w_twice = tf.Variable(weights.initialized_value() * 2.0, name="w_twice")
```
### Custom Initialization <a class="md-anchor" id="AUTOGENERATED-custom-initialization"></a>
### Custom Initialization
The convenience function `tf.initialize_all_variables()` adds an op to
initialize *all variables* in the model. You can also pass it an explicit list
@ -104,7 +104,7 @@ of variables to initialize. See the
[Variables Documentation](../../api_docs/python/state_ops.md) for more options,
including checking if variables are initialized.
## Saving and Restoring <a class="md-anchor" id="AUTOGENERATED-saving-and-restoring"></a>
## Saving and Restoring
The easiest way to save and restore a model is to use a `tf.train.Saver` object.
The constructor adds `save` and `restore` ops to the graph for all, or a
@ -112,7 +112,7 @@ specified list, of the variables in the graph. The saver object provides
methods to run these ops, specifying paths for the checkpoint files to write to
or read from.
### Checkpoint Files <a class="md-anchor" id="AUTOGENERATED-checkpoint-files"></a>
### Checkpoint Files
Variables are saved in binary files that, roughly, contain a map from variable
names to tensor values.
@ -122,7 +122,7 @@ variables in the checkpoint files. By default, it uses the value of the
[`Variable.name`](../../api_docs/python/state_ops.md#Variable.name) property for
each variable.
### Saving Variables <a class="md-anchor" id="AUTOGENERATED-saving-variables"></a>
### Saving Variables
Create a `Saver` with `tf.train.Saver()` to manage all variables in
the model.
@ -149,7 +149,7 @@ with tf.Session() as sess:
print "Model saved in file: ", save_path
```
### Restoring Variables <a class="md-anchor" id="AUTOGENERATED-restoring-variables"></a>
### Restoring Variables
The same `Saver` object is used to restore variables. Note that when you
restore variables from a file you do not have to initialize them beforehand.
@ -172,7 +172,7 @@ with tf.Session() as sess:
...
```
### Choosing which Variables to Save and Restore <a class="md-anchor" id="AUTOGENERATED-choosing-which-variables-to-save-and-restore"></a>
### Choosing which Variables to Save and Restore
If you do not pass any argument to `tf.train.Saver()` the saver handles all
variables in the graph. Each one of them is saved under the name that was

View File

@ -1,8 +1,8 @@
# TensorFlow <a class="md-anchor" id="AUTOGENERATED-tensorflow"></a>
# TensorFlow
<!-- Note: This file is ignored in building the external site tensorflow.org -->
## Introduction <a class="md-anchor" id="AUTOGENERATED-introduction"></a>
## Introduction
TensorFlow&#8482; is an open source software library for numerical computation
using data flow graphs. Nodes in the graph represent mathematical operations,
@ -16,6 +16,6 @@ neural networks research. The system is general enough to be applicable in a
wide variety of other domains as well. The following documents show you how
to set up and use the TensorFlow system.
## Table of Contents <a class="md-anchor" id="AUTOGENERATED-table-of-contents"></a>
## Table of Contents
<!--#include virtual="sitemap.md" -->

View File

@ -1,4 +1,4 @@
# BibTex Citation <a class="md-anchor" id="AUTOGENERATED-bibtex-citation"></a>
# BibTex Citation
If you use TensorFlow in your research and would like to cite the TensorFlow
system, we suggest you cite the following whitepaper:

View File

@ -1,11 +1,11 @@
# Tensor Ranks, Shapes, and Types <a class="md-anchor" id="AUTOGENERATED-tensor-ranks--shapes--and-types"></a>
# Tensor Ranks, Shapes, and Types
TensorFlow programs use a tensor data structure to represent all data. You can
think of a TensorFlow tensor as an n-dimensional array or list.
A tensor has a static type and dynamic dimensions. Only tensors may be passed
between nodes in the computation graph.
## Rank <a class="md-anchor" id="AUTOGENERATED-rank"></a>
## Rank
In the TensorFlow system, tensors are described by a unit of dimensionality
known as *rank*. Tensor rank is not the same as matrix rank. Tensor rank
@ -28,7 +28,7 @@ Rank | Math entity | Python example
3 | 3-Tensor (cube of numbers) | `t = [[[2], [4], [6]], [[8], [10], [12]], [[14], [16], [18]]]`
n | n-Tensor (you get the idea) | `....`
## Shape <a class="md-anchor" id="AUTOGENERATED-shape"></a>
## Shape
The TensorFlow documentation uses three notational conventions to describe
tensor dimensionality: rank, shape, and dimension number. The following table
@ -45,7 +45,7 @@ n | [D0, D1, ... Dn] | n-D | A tensor with shape [D0, D1, ... Dn].
Shapes can be represented via Python lists / tuples of ints, or with the
[`TensorShape` class](../api_docs/python/framework.md#TensorShape).
## Data types <a class="md-anchor" id="AUTOGENERATED-data-types"></a>
## Data types
In addition to dimensionality, Tensors have a data type. You can assign any one
of the following data types to a tensor:

View File

@ -1,29 +1,17 @@
# Frequently Asked Questions <a class="md-anchor" id="AUTOGENERATED-frequently-asked-questions"></a>
# Frequently Asked Questions
This document provides answers to some of the frequently asked questions about
TensorFlow. If you have a question that is not covered here, you might find an
answer on one of the TensorFlow [community resources](../resources/index.md).
<!-- TOC-BEGIN This section is generated by neural network: DO NOT EDIT! -->
## Contents
### [Frequently Asked Questions](#AUTOGENERATED-frequently-asked-questions)
* [Building a TensorFlow graph](#AUTOGENERATED-building-a-tensorflow-graph)
* [Running a TensorFlow computation](#AUTOGENERATED-running-a-tensorflow-computation)
* [Variables](#AUTOGENERATED-variables)
* [Tensor shapes](#AUTOGENERATED-tensor-shapes)
* [TensorBoard](#AUTOGENERATED-tensorboard)
* [Extending TensorFlow](#AUTOGENERATED-extending-tensorflow)
* [Miscellaneous](#AUTOGENERATED-miscellaneous)
[TOC]
<!-- TOC-END This section was generated by neural network, THANKS FOR READING! -->
## Building a TensorFlow graph <a class="md-anchor" id="AUTOGENERATED-building-a-tensorflow-graph"></a>
## Building a TensorFlow graph
See also the
[API documentation on building graphs](../api_docs/python/framework.md).
#### Why does `c = tf.matmul(a, b)` not execute the matrix multiplication immediately? <a class="md-anchor" id="AUTOGENERATED-why-does--c---tf.matmul-a--b---not-execute-the-matrix-multiplication-immediately-"></a>
#### Why does `c = tf.matmul(a, b)` not execute the matrix multiplication immediately?
In the TensorFlow Python API, `a`, `b`, and `c` are
[`Tensor`](../api_docs/python/framework.md#Tensor) objects. A `Tensor` object is
@ -36,12 +24,12 @@ a dataflow graph. You then offload the computation of the entire dataflow graph
whole computation much more efficiently than executing the operations
one-by-one.
#### How are devices named? <a class="md-anchor" id="AUTOGENERATED-how-are-devices-named-"></a>
#### How are devices named?
The supported device names are `"/device:CPU:0"` (or `"/cpu:0"`) for the CPU
device, and `"/device:GPU:i"` (or `"/gpu:i"`) for the *i*th GPU device.
#### How do I place operations on a particular device? <a class="md-anchor" id="AUTOGENERATED-how-do-i-place-operations-on-a-particular-device-"></a>
#### How do I place operations on a particular device?
To place a group of operations on a device, create them within a
[`with tf.device(name):`](../api_docs/python/framework.md#device) context. See
@ -51,17 +39,17 @@ TensorFlow assigns operations to devices, and the
[CIFAR-10 tutorial](../tutorials/deep_cnn/index.md) for an example model that
uses multiple GPUs.
#### What are the different types of tensors that are available? <a class="md-anchor" id="AUTOGENERATED-what-are-the-different-types-of-tensors-that-are-available-"></a>
#### What are the different types of tensors that are available?
TensorFlow supports a variety of different data types and tensor shapes. See the
[ranks, shapes, and types reference](../resources/dims_types.md) for more details.
## Running a TensorFlow computation <a class="md-anchor" id="AUTOGENERATED-running-a-tensorflow-computation"></a>
## Running a TensorFlow computation
See also the
[API documentation on running graphs](../api_docs/python/client.md).
#### What's the deal with feeding and placeholders? <a class="md-anchor" id="AUTOGENERATED-what-s-the-deal-with-feeding-and-placeholders-"></a>
#### What's the deal with feeding and placeholders?
Feeding is a mechanism in the TensorFlow Session API that allows you to
substitute different values for one or more tensors at run time. The `feed_dict`
@ -78,7 +66,7 @@ their shape as well. See the
example of how placeholders and feeding can be used to provide the training data
for a neural network.
#### What is the difference between `Session.run()` and `Tensor.eval()`? <a class="md-anchor" id="AUTOGENERATED-what-is-the-difference-between--session.run----and--tensor.eval----"></a>
#### What is the difference between `Session.run()` and `Tensor.eval()`?
If `t` is a [`Tensor`](../api_docs/python/framework.md#Tensor) object,
[`t.eval()`](../api_docs/python/framework.md#Tensor.eval) is shorthand for
@ -105,7 +93,7 @@ the `with` block. The context manager approach can lead to more concise code for
simple use cases (like unit tests); if your code deals with multiple graphs and
sessions, it may be more straightforward to explicit calls to `Session.run()`.
#### Do Sessions have a lifetime? What about intermediate tensors? <a class="md-anchor" id="AUTOGENERATED-do-sessions-have-a-lifetime--what-about-intermediate-tensors-"></a>
#### Do Sessions have a lifetime? What about intermediate tensors?
Sessions can own resources, such
[variables](../api_docs/python/state_ops.md#Variable),
@ -119,13 +107,13 @@ The intermediate tensors that are created as part of a call to
[`Session.run()`](../api_docs/python/client.md) will be freed at or before the
end of the call.
#### Can I run distributed training on multiple computers? <a class="md-anchor" id="AUTOGENERATED-can-i-run-distributed-training-on-multiple-computers-"></a>
#### Can I run distributed training on multiple computers?
The initial open-source release of TensorFlow supports multiple devices (CPUs
and GPUs) in a single computer. We are working on a distributed version as well:
if you are interested, please let us know so we can prioritize accordingly.
#### Does the runtime parallelize parts of graph execution? <a class="md-anchor" id="AUTOGENERATED-does-the-runtime-parallelize-parts-of-graph-execution-"></a>
#### Does the runtime parallelize parts of graph execution?
The TensorFlow runtime parallelizes graph execution across many different
dimensions:
@ -140,7 +128,7 @@ dimensions:
enables the runtime to get higher throughput, if a single step does not use
all of the resources in your computer.
#### Which client languages are supported in TensorFlow? <a class="md-anchor" id="AUTOGENERATED-which-client-languages-are-supported-in-tensorflow-"></a>
#### Which client languages are supported in TensorFlow?
TensorFlow is designed to support multiple client languages. Currently, the
best-supported client language is [Python](../api_docs/python/index.md). The
@ -154,7 +142,7 @@ interest. TensorFlow has a
that makes it easy to build a client in many different languages. We invite
contributions of new language bindings.
#### Does TensorFlow make use of all the devices (GPUs and CPUs) available on my machine? <a class="md-anchor" id="AUTOGENERATED-does-tensorflow-make-use-of-all-the-devices--gpus-and-cpus--available-on-my-machine-"></a>
#### Does TensorFlow make use of all the devices (GPUs and CPUs) available on my machine?
TensorFlow supports multiple GPUs and CPUs. See the how-to documentation on
[using GPUs with TensorFlow](../how_tos/using_gpu/index.md) for details of how
@ -165,7 +153,7 @@ uses multiple GPUs.
Note that TensorFlow only uses GPU devices with a compute capability greater
than 3.5.
#### Why does `Session.run()` hang when using a reader or a queue? <a class="md-anchor" id="AUTOGENERATED-why-does--session.run----hang-when-using-a-reader-or-a-queue-"></a>
#### Why does `Session.run()` hang when using a reader or a queue?
The [reader](../api_docs/python/io_ops.md#ReaderBase) and
[queue](../api_docs/python/io_ops.md#QueueBase) classes provide special operations that
@ -177,20 +165,20 @@ for
[using `QueueRunner` objects to drive queues and readers](../how_tos/reading_data/index.md#QueueRunners)
for more information on how to use them.
## Variables <a class="md-anchor" id="AUTOGENERATED-variables"></a>
## Variables
See also the how-to documentation on [variables](../how_tos/variables/index.md)
and [variable scopes](../how_tos/variable_scope/index.md), and
[the API documentation for variables](../api_docs/python/state_ops.md).
#### What is the lifetime of a variable? <a class="md-anchor" id="AUTOGENERATED-what-is-the-lifetime-of-a-variable-"></a>
#### What is the lifetime of a variable?
A variable is created when you first run the
[`tf.Variable.initializer`](../api_docs/python/state_ops.md#Variable.initializer)
operation for that variable in a session. It is destroyed when that
[`session is closed`](../api_docs/python/client.md#Session.close).
#### How do variables behave when they are concurrently accessed? <a class="md-anchor" id="AUTOGENERATED-how-do-variables-behave-when-they-are-concurrently-accessed-"></a>
#### How do variables behave when they are concurrently accessed?
Variables allow concurrent read and write operations. The value read from a
variable may change it is concurrently updated. By default, concurrent assigment
@ -198,12 +186,12 @@ operations to a variable are allowed to run with no mutual exclusion. To acquire
a lock when assigning to a variable, pass `use_locking=True` to
[`Variable.assign()`](../api_docs/python/state_ops.md#Variable.assign).
## Tensor shapes <a class="md-anchor" id="AUTOGENERATED-tensor-shapes"></a>
## Tensor shapes
See also the
[`TensorShape` API documentation](../api_docs/python/framework.md#TensorShape).
#### How can I determine the shape of a tensor in Python? <a class="md-anchor" id="AUTOGENERATED-how-can-i-determine-the-shape-of-a-tensor-in-python-"></a>
#### How can I determine the shape of a tensor in Python?
In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true)
shape. The static shape can be read using the
@ -214,7 +202,7 @@ tensor, and may be
shape is not fully defined, the dynamic shape of a `Tensor` `t` can be
determined by evaluating [`tf.shape(t)`](../api_docs/python/array_ops.md#shape).
#### What is the difference between `x.set_shape()` and `x = tf.reshape(x)`? <a class="md-anchor" id="AUTOGENERATED-what-is-the-difference-between--x.set_shape----and--x---tf.reshape-x---"></a>
#### What is the difference between `x.set_shape()` and `x = tf.reshape(x)`?
The [`tf.Tensor.set_shape()`](../api_docs/python/framework.md) method updates
the static shape of a `Tensor` object, and it is typically used to provide
@ -224,7 +212,7 @@ change the dynamic shape of the tensor.
The [`tf.reshape()`](../api_docs/python/array_ops.md#reshape) operation creates
a new tensor with a different dynamic shape.
#### How do I build a graph that works with variable batch sizes? <a class="md-anchor" id="AUTOGENERATED-how-do-i-build-a-graph-that-works-with-variable-batch-sizes-"></a>
#### How do I build a graph that works with variable batch sizes?
It is often useful to build a graph that works with variable batch sizes, for
example so that the same code can be used for (mini-)batch training, and
@ -250,13 +238,13 @@ to encode the batch size as a Python constant, but instead to use a symbolic
[`tf.placeholder(..., shape=[None, ...])`](../api_docs/python/io_ops.md#placeholder). The
`None` element of the shape corresponds to a variable-sized dimension.
## TensorBoard <a class="md-anchor" id="AUTOGENERATED-tensorboard"></a>
## TensorBoard
#### How can I visualize a TensorFlow graph? <a class="md-anchor" id="AUTOGENERATED-how-can-i-visualize-a-tensorflow-graph-"></a>
#### How can I visualize a TensorFlow graph?
See the [graph visualization tutorial](../how_tos/graph_viz/index.md).
#### What is the simplest way to send data to TensorBoard? <a class="md-anchor" id="AUTOGENERATED-what-is-the-simplest-way-to-send-data-to-tensorboard-"></a>
#### What is the simplest way to send data to TensorBoard?
Add summary ops to your TensorFlow graph, and use a
[`SummaryWriter`](../api_docs/python/train.md#SummaryWriter) to write
@ -267,12 +255,12 @@ these summaries to a log directory. Then, start TensorBoard using
For more details, see the [Summaries and TensorBoard tutorial]
(../how_tos/summaries_and_tensorboard/index.md).
## Extending TensorFlow <a class="md-anchor" id="AUTOGENERATED-extending-tensorflow"></a>
## Extending TensorFlow
See also the how-to documentation for
[adding a new operation to TensorFlow](../how_tos/adding_an_op/index.md).
#### My data is in a custom format. How do I read it using TensorFlow? <a class="md-anchor" id="AUTOGENERATED-my-data-is-in-a-custom-format.-how-do-i-read-it-using-tensorflow-"></a>
#### My data is in a custom format. How do I read it using TensorFlow?
There are two main options for dealing with data in a custom format.
@ -290,7 +278,7 @@ data format. The
[guide to handling new data formats](../how_tos/new_data_formats/index.md) has
more information about the steps for doing this.
#### How do I define an operation that takes a variable number of inputs? <a class="md-anchor" id="AUTOGENERATED-how-do-i-define-an-operation-that-takes-a-variable-number-of-inputs-"></a>
#### How do I define an operation that takes a variable number of inputs?
The TensorFlow op registration mechanism allows you to define inputs that are a
single tensor, a list of tensors with the same type (for example when adding
@ -300,15 +288,15 @@ how-to documentation for
[adding an op with a list of inputs or outputs](../how_tos/adding_an_op/index.md#list-input-output)
for more details of how to define these different input types.
## Miscellaneous <a class="md-anchor" id="AUTOGENERATED-miscellaneous"></a>
## Miscellaneous
#### Does TensorFlow work with Python 3? <a class="md-anchor" id="AUTOGENERATED-does-tensorflow-work-with-python-3-"></a>
#### Does TensorFlow work with Python 3?
We have only tested TensorFlow using Python 2.7. We are aware of some changes
that will be required for Python 3 compatibility, and welcome contributions
towards this effort.
#### What is TensorFlow's coding style convention? <a class="md-anchor" id="AUTOGENERATED-what-is-tensorflow-s-coding-style-convention-"></a>
#### What is TensorFlow's coding style convention?
The TensorFlow Python API adheres to the
[PEP8](https://www.python.org/dev/peps/pep-0008/) conventions.<sup>*</sup> In

View File

@ -1,4 +1,4 @@
# Glossary <a class="md-anchor" id="AUTOGENERATED-glossary"></a>
# Glossary
**Broadcasting operation**

View File

@ -1,14 +1,14 @@
# Additional Resources <a class="md-anchor" id="AUTOGENERATED-additional-resources"></a>
# Additional Resources
## TensorFlow WhitePaper <a class="md-anchor" id="AUTOGENERATED-tensorflow-whitepaper"></a>
## TensorFlow WhitePaper
Additional details about the TensorFlow programming model and the underlying
implementation can be found in out white paper:
* [TensorFlow: Large-scale machine learning on heterogeneous systems](http://download.tensorflow.org/paper/whitepaper2015.pdf)
### Citation <a class="md-anchor" id="AUTOGENERATED-citation"></a>
### Citation
If you use TensorFlow in your research and would like to cite the TensorFlow
system, we suggest you cite the paper above.
@ -21,9 +21,9 @@ endorsed by or otherwise affiliated with Google. When referring to our marks,
please include the following attribution statement: "TensorFlow, the TensorFlow
logo and any related marks are trademarks of Google Inc."
## Community <a class="md-anchor" id="AUTOGENERATED-community"></a>
## Community
### Development <a class="md-anchor" id="AUTOGENERATED-development"></a>
### Development
The source is hosted on GitHub: <https://github.com/tensorflow/tensorflow>.
@ -31,14 +31,14 @@ If you are interested in contributing to TensorFlow please
[review the contributing guide](
https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md).
### Help / Support / How do I? <a class="md-anchor" id="AUTOGENERATED-help---support---how-do-i-"></a>
### Help / Support / How do I?
For help and support, technical or algorithmic questions, please submit
your questions to Stack Overflow:
<https://stackoverflow.com/questions/tagged/tensorflow>.
Please do not use the mailing list or issue tracker for support.
### Discussions <a class="md-anchor" id="AUTOGENERATED-discussions"></a>
### Discussions
For general discussions, please join the [TensorFlow discuss mailing list](
https://groups.google.com/a/tensorflow.org/d/forum/discuss).
@ -47,7 +47,7 @@ directions, not as a help forum. Instead, direct your questions to
[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow), and
report issues on [GitHub](https://github.com/tensorflow/tensorflow/issues).
### Report Issues <a class="md-anchor" id="AUTOGENERATED-report-issues"></a>
### Report Issues
Please report bugs, feature requests and installation / compatibility issues on
the [TensorFlow issues tracker](

View File

@ -1,4 +1,4 @@
# Example Uses <a class="md-anchor" id="AUTOGENERATED-example-uses"></a>
# Example Uses
This page describes some of the current uses of the TensorFlow system.

View File

@ -1,9 +1,9 @@
# Convolutional Neural Networks <a class="md-anchor" id="AUTOGENERATED-convolutional-neural-networks"></a>
# Convolutional Neural Networks
> **NOTE:** This tutorial is intended for *advanced* users of TensorFlow
and assumes expertise and experience in machine learning.
## Overview <a class="md-anchor" id="AUTOGENERATED-overview"></a>
## Overview
CIFAR-10 classification is a common benchmark problem in machine learning. The
problem is to classify RGB 32x32 pixel images across 10 categories:
@ -15,7 +15,7 @@ For more details refer to the [CIFAR-10 page](http://www.cs.toronto.edu/~kriz/ci
and a [Tech Report](http://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf)
by Alex Krizhevsky.
### Goals <a class="md-anchor" id="AUTOGENERATED-goals"></a>
### Goals
The goal of this tutorial is to build a relatively small convolutional neural
network (CNN) for recognizing images. In the process, this tutorial:
@ -29,7 +29,7 @@ much of TensorFlow's ability to scale to large models. At the same time,
the model is small enough to train fast, which is ideal for trying out
new ideas and experimenting with new techniques.
### Highlights of the Tutorial <a class="md-anchor" id="AUTOGENERATED-highlights-of-the-tutorial"></a>
### Highlights of the Tutorial
The CIFAR-10 tutorial demonstrates several important constructs for
designing larger and more sophisticated models in TensorFlow:
@ -60,7 +60,7 @@ We also provide a multi-GPU version of the model which demonstrates:
We hope that this tutorial provides a launch point for building larger CNNs for
vision tasks on TensorFlow.
### Model Architecture <a class="md-anchor" id="AUTOGENERATED-model-architecture"></a>
### Model Architecture
The model in this CIFAR-10 tutorial is a multi-layer architecture consisting of
alternating convolutions and nonlinearities. These layers are followed by fully
@ -74,7 +74,7 @@ of training time on a GPU. Please see [below](#evaluating-a-model) and the code
for details. It consists of 1,068,298 learnable parameters and requires about
19.5M multiply-add operations to compute inference on a single image.
## Code Organization <a class="md-anchor" id="AUTOGENERATED-code-organization"></a>
## Code Organization
The code for this tutorial resides in
[`tensorflow/models/image/cifar10/`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/).
@ -88,7 +88,7 @@ File | Purpose
[`cifar10_eval.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model.
## CIFAR-10 Model <a class="md-anchor" id="AUTOGENERATED-cifar-10-model"></a>
## CIFAR-10 Model
The CIFAR-10 network is largely contained in
[`cifar10.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/cifar10.py).
@ -105,7 +105,7 @@ adds operations that perform inference, i.e. classification, on supplied images.
add operations that compute the loss,
gradients, variable updates and visualization summaries.
### Model Inputs <a class="md-anchor" id="model-inputs"></a>
### Model Inputs {#model-inputs}
The input part of the model is built by the functions `inputs()` and
`distorted_inputs()` which read images from the CIFAR-10 binary data files.
@ -143,7 +143,7 @@ processing time. To prevent these operations from slowing down training, we run
them inside 16 separate threads which continuously fill a TensorFlow
[queue](../../api_docs/python/io_ops.md#shuffle_batch).
### Model Prediction <a class="md-anchor" id="model-prediction"></a>
### Model Prediction {#model-prediction}
The prediction part of the model is constructed by the `inference()` function
which adds operations to compute the *logits* of the predictions. That part of
@ -182,7 +182,7 @@ layers of Alex's original model are locally connected and not fully connected.
Try editing the architecture to exactly reproduce the locally connected
architecture in the top layer.
### Model Training <a class="md-anchor" id="model-training"></a>
### Model Training {#model-training}
The usual method for training a network to perform N-way classification is
[multinomial logistic regression](https://en.wikipedia.org/wiki/Multinomial_logistic_regression),
@ -216,7 +216,7 @@ calculating the gradient and updating the learned variables (see
for details). It returns an operation that executes all the calculations
needed to train and update the model for one batch of images.
## Launching and Training the Model <a class="md-anchor" id="AUTOGENERATED-launching-and-training-the-model"></a>
## Launching and Training the Model
We have built the model, let's now launch it and run the training operation with
the script `cifar10_train.py`.
@ -301,7 +301,7 @@ values. See how the scripts use
[`ExponentialMovingAverage`](../../api_docs/python/train.md#ExponentialMovingAverage)
for this purpose.
## Evaluating a Model <a class="md-anchor" id="evaluating-a-model"></a>
## Evaluating a Model {#evaluating-a-model}
Let us now evaluate how well the trained model performs on a hold-out data set.
The model is evaluated by the script `cifar10_eval.py`. It constructs the model
@ -345,7 +345,7 @@ the averaged parameters for the model and verify that the predictive performance
drops.
## Training a Model Using Multiple GPU Cards <a class="md-anchor" id="AUTOGENERATED-training-a-model-using-multiple-gpu-cards"></a>
## Training a Model Using Multiple GPU Cards
Modern workstations may contain multiple GPUs for scientific computation.
TensorFlow can leverage this environment to run the training operation
@ -389,7 +389,7 @@ The GPUs are synchronized in operation. All gradients are accumulated from
the GPUs and averaged (see green box). The model parameters are updated with
the gradients averaged across all model replicas.
### Placing Variables and Operations on Devices <a class="md-anchor" id="AUTOGENERATED-placing-variables-and-operations-on-devices"></a>
### Placing Variables and Operations on Devices
Placing operations and variables on devices requires some special
abstractions.
@ -413,7 +413,7 @@ All variables are pinned to the CPU and accessed via
in order to share them in a multi-GPU version.
See how-to on [Sharing Variables](../../how_tos/variable_scope/index.md).
### Launching and Training the Model on Multiple GPU cards <a class="md-anchor" id="AUTOGENERATED-launching-and-training-the-model-on-multiple-gpu-cards"></a>
### Launching and Training the Model on Multiple GPU cards
If you have several GPU cards installed on your machine you can use them to
train the model faster with the `cifar10_multi_gpu_train.py` script. This
@ -444,7 +444,7 @@ you ask for more.
run on a batch size of 128. Try running `cifar10_multi_gpu_train.py` on 2 GPUs
with a batch size of 64 and compare the training speed.
## Next Steps <a class="md-anchor" id="AUTOGENERATED-next-steps"></a>
## Next Steps
[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You have
completed the CIFAR-10 tutorial.

View File

@ -1,4 +1,4 @@
# Mandelbrot Set <a class="md-anchor" id="AUTOGENERATED-mandelbrot-set"></a>
# Mandelbrot Set
Visualizing the Mandelbrot set doesn't have anything to do with machine
learning, but it makes for a fun example of how one can use TensorFlow for
@ -8,7 +8,7 @@ elaborate implementation down the line to produce more truly beautiful images.)
Note: This tutorial was originally prepared as an IPython notebook.
## Basic Setup <a class="md-anchor" id="AUTOGENERATED-basic-setup"></a>
## Basic Setup
We'll need a few imports to get started.
@ -43,7 +43,7 @@ def DisplayFractal(a, fmt='jpeg'):
display(Image(data=f.getvalue()))
```
## Session and Variable Initialization <a class="md-anchor" id="AUTOGENERATED-session-and-variable-initialization"></a>
## Session and Variable Initialization
For playing around like this, we often use an interactive session, but a regular
session would work as well.
@ -75,7 +75,7 @@ TensorFlow requires that you explicitly initialize variables before using them.
tf.initialize_all_variables().run()
```
## Defining and Running the Computation <a class="md-anchor" id="AUTOGENERATED-defining-and-running-the-computation"></a>
## Defining and Running the Computation
Now we specify more of the computation...

View File

@ -1,4 +1,4 @@
# MNIST For ML Beginners <a class="md-anchor" id="AUTOGENERATED-mnist-for-ml-beginners"></a>
# MNIST For ML Beginners
*This tutorial is intended for readers who are new to both machine learning and
TensorFlow. If you already
@ -31,7 +31,7 @@ important to understand the ideas behind it: both how TensorFlow works and the
core machine learning concepts. Because of this, we are going to very carefully
work through the code.
## The MNIST Data <a class="md-anchor" id="AUTOGENERATED-the-mnist-data"></a>
## The MNIST Data
The MNIST data is hosted on
[Yann LeCun's website](http://yann.lecun.com/exdb/mnist/). For your
@ -101,7 +101,7 @@ Consequently, `mnist.train.labels` is a
We're now ready to actually make our model!
## Softmax Regressions <a class="md-anchor" id="AUTOGENERATED-softmax-regressions"></a>
## Softmax Regressions
We know that every image in MNIST is a digit, whether it's a zero or a nine. We
want to be able to look at an image and give probabilities for it being each
@ -196,7 +196,7 @@ More compactly, we can just write:
$$y = \text{softmax}(Wx + b)$$
## Implementing the Regression <a class="md-anchor" id="AUTOGENERATED-implementing-the-regression"></a>
## Implementing the Regression
To do efficient numerical computing in Python, we typically use libraries like
@ -276,7 +276,7 @@ simulations. And once defined, our model can be run on different devices:
your computer's CPU, GPUs, and even phones!
## Training <a class="md-anchor" id="AUTOGENERATED-training"></a>
## Training
In order to train our model, we need to define what it means for the model to
be good. Well, actually, in machine learning we typically define what it means
@ -380,7 +380,7 @@ every time. Doing this is cheap and has much of the same benefit.
## Evaluating Our Model <a class="md-anchor" id="AUTOGENERATED-evaluating-our-model"></a>
## Evaluating Our Model
How well does our model do?

View File

@ -1,11 +1,11 @@
# MNIST Data Download <a class="md-anchor" id="AUTOGENERATED-mnist-data-download"></a>
# MNIST Data Download
Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/)
The goal of this tutorial is to show how to download the dataset files required
for handwritten digit classification using the (classic) MNIST data set.
## Tutorial Files <a class="md-anchor" id="AUTOGENERATED-tutorial-files"></a>
## Tutorial Files
This tutorial references the following files:
@ -13,7 +13,7 @@ File | Purpose
--- | ---
[`input_data.py`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py) | The code to download the MNIST dataset for training and evaluation.
## Prepare the Data <a class="md-anchor" id="AUTOGENERATED-prepare-the-data"></a>
## Prepare the Data
MNIST is a classic problem in machine learning. The problem is to look at
greyscale 28x28 pixel images of handwritten digits and determine which digit
@ -24,7 +24,7 @@ the image represents, for all the digits from zero to nine.
For more information, refer to [Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
or [Chris Olah's visualizations of MNIST](http://colah.github.io/posts/2014-10-Visualizing-MNIST/).
### Download <a class="md-anchor" id="AUTOGENERATED-download"></a>
### Download
[Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
also hosts the training and test data for download.
@ -42,7 +42,7 @@ files are downloaded into a local data folder for training.
The folder name is specified in a flag variable at the top of the
`fully_connected_feed.py` file and may be changed to fit your needs.
### Unpack and Reshape <a class="md-anchor" id="AUTOGENERATED-unpack-and-reshape"></a>
### Unpack and Reshape
The files themselves are not in any standard image format and are manually
unpacked (following the instructions available at the website) by the
@ -64,7 +64,7 @@ The label data is extracted into a 1d tensor of: `[image index]`
with the class identifier for each example as the value. For the training set
labels, this would then be of shape `[55000]`.
### DataSet Object <a class="md-anchor" id="AUTOGENERATED-dataset-object"></a>
### DataSet Object
The underlying code will download, unpack, and reshape images and labels for
the following datasets:

View File

@ -1,4 +1,4 @@
# Deep MNIST for Experts <a class="md-anchor" id="AUTOGENERATED-deep-mnist-for-experts"></a>
# Deep MNIST for Experts
TensorFlow is a powerful library for doing large-scale numerical computation.
One of the tasks at which it excels is implementing and training deep neural
@ -11,12 +11,12 @@ dataset. If you don't have
a background with them, check out the
[introduction for beginners](../../../tutorials/mnist/beginners/index.md).*
## Setup <a class="md-anchor" id="AUTOGENERATED-setup"></a>
## Setup
Before we create our model, we will first load the MNIST dataset, and start a
TensorFlow session.
### Load MNIST Data <a class="md-anchor" id="AUTOGENERATED-load-mnist-data"></a>
### Load MNIST Data
For your convenience, we've included
[a script](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py)
@ -33,7 +33,7 @@ testing sets as NumPy arrays.
It also provides a function for iterating through data minibatches, which we
will use below.
### Start TensorFlow InteractiveSession <a class="md-anchor" id="AUTOGENERATED-start-tensorflow-interactivesession"></a>
### Start TensorFlow InteractiveSession
Tensorflow relies on a highly efficient C++ backend to do its computation. The
connection to this backend is called a session. The common usage for TensorFlow
@ -56,7 +56,7 @@ import tensorflow as tf
sess = tf.InteractiveSession()
```
#### Computation Graph <a class="md-anchor" id="AUTOGENERATED-computation-graph"></a>
#### Computation Graph
To do efficient numerical computing in Python, we typically use libraries like
NumPy that do expensive operations such as matrix multiplication outside Python,
@ -81,13 +81,13 @@ section of
[Basic Usage](../../../get_started/basic_usage.md)
for more detail.
## Build a Softmax Regression Model <a class="md-anchor" id="AUTOGENERATED-build-a-softmax-regression-model"></a>
## Build a Softmax Regression Model
In this section we will build a softmax regression model with a single linear
layer. In the next section, we will extend this to the case of softmax
regression with a multilayer convolutional network.
### Placeholders <a class="md-anchor" id="AUTOGENERATED-placeholders"></a>
### Placeholders
We start building the computation graph by creating nodes for the
input images and target output classes.
@ -111,7 +111,7 @@ which digit class the corresponding MNIST image belongs to.
The `shape` argument to `placeholder` is optional, but it allows TensorFlow
to automatically catch bugs stemming from inconsistent tensor shapes.
### Variables <a class="md-anchor" id="AUTOGENERATED-variables"></a>
### Variables
We now define the weights `W` and biases `b` for our model. We could imagine treating
these like additional inputs, but TensorFlow has an even better way to handle
@ -140,7 +140,7 @@ done for all `Variables` at once.
sess.run(tf.initialize_all_variables())
```
### Predicted Class and Cost Function <a class="md-anchor" id="AUTOGENERATED-predicted-class-and-cost-function"></a>
### Predicted Class and Cost Function
We can now implement our regression model. It only takes one line!
We multiply the vectorized input images `x` by the weight matrix `W`, add
@ -162,7 +162,7 @@ cross_entropy = -tf.reduce_sum(y_*tf.log(y))
Note that `tf.reduce_sum` sums across all images in the minibatch, as well as
all classes. We are computing the cross entropy for the entire minibatch.
## Train the Model <a class="md-anchor" id="AUTOGENERATED-train-the-model"></a>
## Train the Model
Now that we have defined our model and training cost function, it is
straightforward to train using TensorFlow.
@ -199,7 +199,7 @@ Each training iteration we load 50 training examples. We then run the
Note that you can replace any tensor in your computation graph using `feed_dict`
-- it's not restricted to just `placeholder`s.
### Evaluate the Model <a class="md-anchor" id="AUTOGENERATED-evaluate-the-model"></a>
### Evaluate the Model
How well did our model do?
@ -229,14 +229,14 @@ Finally, we can evaluate our accuracy on the test data. This should be about
print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
```
## Build a Multilayer Convolutional Network <a class="md-anchor" id="AUTOGENERATED-build-a-multilayer-convolutional-network"></a>
## Build a Multilayer Convolutional Network
Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
section, we'll fix that, jumping from a very simple model to something
moderately sophisticated: a small convolutional neural network. This will get us
to around 99.2% accuracy -- not state of the art, but respectable.
### Weight Initialization <a class="md-anchor" id="AUTOGENERATED-weight-initialization"></a>
### Weight Initialization
To create this model, we're going to need to create a lot of weights and biases.
One should generally initialize weights with a small amount of noise for
@ -255,7 +255,7 @@ def bias_variable(shape):
return tf.Variable(initial)
```
### Convolution and Pooling <a class="md-anchor" id="AUTOGENERATED-convolution-and-pooling"></a>
### Convolution and Pooling
TensorFlow also gives us a lot of flexibility in convolution and pooling
operations. How do we handle the boundaries? What is our stride size?
@ -274,7 +274,7 @@ def max_pool_2x2(x):
strides=[1, 2, 2, 1], padding='SAME')
```
### First Convolutional Layer <a class="md-anchor" id="AUTOGENERATED-first-convolutional-layer"></a>
### First Convolutional Layer
We can now implement our first layer. It will consist of convolution, followed
by max pooling. The convolutional will compute 32 features for each 5x5 patch.
@ -304,7 +304,7 @@ h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
```
### Second Convolutional Layer <a class="md-anchor" id="AUTOGENERATED-second-convolutional-layer"></a>
### Second Convolutional Layer
In order to build a deep network, we stack several layers of this type. The
second layer will have 64 features for each 5x5 patch.
@ -317,7 +317,7 @@ h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
```
### Densely Connected Layer <a class="md-anchor" id="AUTOGENERATED-densely-connected-layer"></a>
### Densely Connected Layer
Now that the image size has been reduced to 7x7, we add a fully-connected layer
with 1024 neurons to allow processing on the entire image. We reshape the tensor
@ -332,7 +332,7 @@ h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
```
#### Dropout <a class="md-anchor" id="AUTOGENERATED-dropout"></a>
#### Dropout
To reduce overfitting, we will apply dropout before the readout layer.
We create a `placeholder` for the probability that a neuron's output is kept
@ -346,7 +346,7 @@ keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
```
### Readout Layer <a class="md-anchor" id="AUTOGENERATED-readout-layer"></a>
### Readout Layer
Finally, we add a softmax layer, just like for the one layer softmax regression
above.
@ -358,7 +358,7 @@ b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
```
### Train and Evaluate the Model <a class="md-anchor" id="AUTOGENERATED-train-and-evaluate-the-model"></a>
### Train and Evaluate the Model
How well does this model do?
To train and evaluate it we will use code that is nearly identical to that for

View File

@ -1,4 +1,4 @@
# TensorFlow Mechanics 101 <a class="md-anchor" id="AUTOGENERATED-tensorflow-mechanics-101"></a>
# TensorFlow Mechanics 101
Code: [tensorflow/g3doc/tutorials/mnist/](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/)
@ -12,7 +12,7 @@ These tutorials are not intended for teaching Machine Learning in general.
Please ensure you have followed the instructions to [install TensorFlow](../../../get_started/os_setup.md).
## Tutorial Files <a class="md-anchor" id="AUTOGENERATED-tutorial-files"></a>
## Tutorial Files
This tutorial references the following files:
@ -25,7 +25,7 @@ Simply run the `fully_connected_feed.py` file directly to start training:
`python fully_connected_feed.py`
## Prepare the Data <a class="md-anchor" id="AUTOGENERATED-prepare-the-data"></a>
## Prepare the Data
MNIST is a classic problem in machine learning. The problem is to look at
greyscale 28x28 pixel images of handwritten digits and determine which digit
@ -36,7 +36,7 @@ the image represents, for all the digits from zero to nine.
For more information, refer to [Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
or [Chris Olah's visualizations of MNIST](http://colah.github.io/posts/2014-10-Visualizing-MNIST/).
### Download <a class="md-anchor" id="AUTOGENERATED-download"></a>
### Download
At the top of the `run_training()` method, the `input_data.read_data_sets()`
function will ensure that the correct data has been downloaded to your local
@ -59,7 +59,7 @@ Dataset | Purpose
For more information about the data, please read the [Download](../../../tutorials/mnist/download/index.md)
tutorial.
### Inputs and Placeholders <a class="md-anchor" id="AUTOGENERATED-inputs-and-placeholders"></a>
### Inputs and Placeholders
The `placeholder_inputs()` function creates two [`tf.placeholder`](../../../api_docs/python/io_ops.md#placeholder)
ops that define the shape of the inputs, including the `batch_size`, to the
@ -76,7 +76,7 @@ sliced to fit the `batch_size` for each step, matched with these placeholder
ops, and then passed into the `sess.run()` function using the `feed_dict`
parameter.
## Build the Graph <a class="md-anchor" id="AUTOGENERATED-build-the-graph"></a>
## Build the Graph
After creating placeholders for the data, the graph is built from the
`mnist.py` file according to a 3-stage pattern: `inference()`, `loss()`, and
@ -93,7 +93,7 @@ and apply gradients.
<img style="width:100%" src="./mnist_subgraph.png">
</div>
### Inference <a class="md-anchor" id="AUTOGENERATED-inference"></a>
### Inference
The `inference()` function builds the graph as far as needed to
return the tensor that would contain the output predictions.
@ -162,7 +162,7 @@ logits = tf.matmul(hidden2, weights) + biases
Finally, the `logits` tensor that will contain the output is returned.
### Loss <a class="md-anchor" id="AUTOGENERATED-loss"></a>
### Loss
The `loss()` function further builds the graph by adding the required loss
ops.
@ -205,7 +205,7 @@ And the tensor that will then contain the loss value is returned.
> given what is actually true. For more information, read the blog post Visual
> Information Theory (http://colah.github.io/posts/2015-09-Visual-Information/)
### Training <a class="md-anchor" id="AUTOGENERATED-training"></a>
### Training
The `training()` function adds the operations needed to minimize the loss via
gradient descent.
@ -241,12 +241,12 @@ train_op = optimizer.minimize(loss, global_step=global_step)
The tensor containing the outputs of the training op is returned.
## Train the Model <a class="md-anchor" id="AUTOGENERATED-train-the-model"></a>
## Train the Model
Once the graph is built, it can be iteratively trained and evaluated in a loop
controlled by the user code in `fully_connected_feed.py`.
### The Graph <a class="md-anchor" id="AUTOGENERATED-the-graph"></a>
### The Graph
At the top of the `run_training()` function is a python `with` command that
indicates all of the built ops are to be associated with the default
@ -263,7 +263,7 @@ Most TensorFlow uses will only need to rely on the single default graph.
More complicated uses with multiple graphs are possible, but beyond the scope of
this simple tutorial.
### The Session <a class="md-anchor" id="AUTOGENERATED-the-session"></a>
### The Session
Once all of the build preparation has been completed and all of the necessary
ops generated, a [`tf.Session`](../../../api_docs/python/client.md#Session)
@ -298,7 +298,7 @@ op is a [`tf.group`](../../../api_docs/python/control_flow_ops.md#group)
that contains only the initializers for the variables. None of the rest of the
graph is run here; that happens in the training loop below.
### Train Loop <a class="md-anchor" id="AUTOGENERATED-train-loop"></a>
### Train Loop
After initializing the variables with the session, training may begin.
@ -313,7 +313,7 @@ for step in xrange(max_steps):
However, this tutorial is slightly more complicated in that it must also slice
up the input data for each step to match the previously generated placeholders.
#### Feed the Graph <a class="md-anchor" id="AUTOGENERATED-feed-the-graph"></a>
#### Feed the Graph
For each step, the code will generate a feed dictionary that will contain the
set of examples on which to train for the step, keyed by the placeholder
@ -340,7 +340,7 @@ feed_dict = {
This is passed into the `sess.run()` function's `feed_dict` parameter to provide
the input examples for this step of training.
#### Check the Status <a class="md-anchor" id="AUTOGENERATED-check-the-status"></a>
#### Check the Status
The code specifies two values to fetch in its run call: `[train_op, loss]`.
@ -370,7 +370,7 @@ if step % 100 == 0:
print 'Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)
```
#### Visualize the Status <a class="md-anchor" id="AUTOGENERATED-visualize-the-status"></a>
#### Visualize the Status
In order to emit the events files used by [TensorBoard](../../../how_tos/summaries_and_tensorboard/index.md),
all of the summaries (in this case, only one) are collected into a single op
@ -405,7 +405,7 @@ folder to display the values from the summaries.
**NOTE**: For more info about how to build and run Tensorboard, please see the accompanying tutorial [Tensorboard: Visualizing Your Training](../../../how_tos/summaries_and_tensorboard/index.md).
#### Save a Checkpoint <a class="md-anchor" id="AUTOGENERATED-save-a-checkpoint"></a>
#### Save a Checkpoint
In order to emit a checkpoint file that may be used to later restore a model
for further training or evaluation, we instantiate a
@ -431,7 +431,7 @@ method to reload the model parameters.
saver.restore(sess, FLAGS.train_dir)
```
## Evaluate the Model <a class="md-anchor" id="AUTOGENERATED-evaluate-the-model"></a>
## Evaluate the Model
Every thousand steps, the code will attempt to evaluate the model against both
the training and test datasets. The `do_eval()` function is called thrice, for
@ -463,7 +463,7 @@ do_eval(sess,
> the sake of a simple little MNIST problem, however, we evaluate against all of
> the data.
### Build the Eval Graph <a class="md-anchor" id="AUTOGENERATED-build-the-eval-graph"></a>
### Build the Eval Graph
Before opening the default Graph, the test data should have been fetched by
calling the `get_data(train=False)` function with the parameter set to grab
@ -490,7 +490,7 @@ of K to 1 to only consider a prediction correct if it is for the true label.
eval_correct = tf.nn.in_top_k(logits, labels, 1)
```
### Eval Output <a class="md-anchor" id="AUTOGENERATED-eval-output"></a>
### Eval Output
One can then create a loop for filling a `feed_dict` and calling `sess.run()`
against the `eval_correct` op to evaluate the model on the given dataset.

View File

@ -1,4 +1,4 @@
# Partial Differential Equations <a class="md-anchor" id="AUTOGENERATED-partial-differential-equations"></a>
# Partial Differential Equations
TensorFlow isn't just for machine learning. Here we give a (somewhat
pedestrian) example of using TensorFlow for simulating the behavior of a
@ -7,7 +7,7 @@ few raindrops land on it.
Note: This tutorial was originally prepared as an IPython notebook.
## Basic Setup <a class="md-anchor" id="AUTOGENERATED-basic-setup"></a>
## Basic Setup
A few imports we'll need.
@ -42,7 +42,7 @@ executable .py file.
sess = tf.InteractiveSession()
```
## Computational Convenience Functions <a class="md-anchor" id="AUTOGENERATED-computational-convenience-functions"></a>
## Computational Convenience Functions
```python
@ -66,7 +66,7 @@ def laplace(x):
return simple_conv(x, laplace_k)
```
## Define the PDE <a class="md-anchor" id="AUTOGENERATED-define-the-pde"></a>
## Define the PDE
Our pond is a perfect 500 x 500 square, as is the case for most ponds found in
nature.
@ -119,7 +119,7 @@ step = tf.group(
Ut.assign(Ut_))
```
## Run The Simulation <a class="md-anchor" id="AUTOGENERATED-run-the-simulation"></a>
## Run The Simulation
This is where it gets fun -- running time forward with a simple for loop.

View File

@ -1,12 +1,12 @@
# Recurrent Neural Networks <a class="md-anchor" id="AUTOGENERATED-recurrent-neural-networks"></a>
# Recurrent Neural Networks
## Introduction <a class="md-anchor" id="AUTOGENERATED-introduction"></a>
## Introduction
Take a look at [this great article]
(http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
for an introduction to recurrent neural networks and LSTMs in particular.
## Language Modeling <a class="md-anchor" id="AUTOGENERATED-language-modeling"></a>
## Language Modeling
In this tutorial we will show how to train a recurrent neural network on
a challenging task of language modeling. The goal of the problem is to fit a
@ -24,7 +24,7 @@ For the purpose of this tutorial, we will reproduce the results from
[Zaremba et al., 2014] (http://arxiv.org/abs/1409.2329), which achieves very
good results on the PTB dataset.
## Tutorial Files <a class="md-anchor" id="AUTOGENERATED-tutorial-files"></a>
## Tutorial Files
This tutorial references the following files from `models/rnn/ptb`:
@ -33,7 +33,7 @@ File | Purpose
`ptb_word_lm.py` | The code to train a language model on the PTB dataset.
`reader.py` | The code to read the dataset.
## Download and Prepare the Data <a class="md-anchor" id="AUTOGENERATED-download-and-prepare-the-data"></a>
## Download and Prepare the Data
The data required for this tutorial is in the data/ directory of the
PTB dataset from Tomas Mikolov's webpage:
@ -44,9 +44,9 @@ including the end-of-sentence marker and a special symbol (\<unk\>) for rare
words. We convert all of them in the `reader.py` to unique integer identifiers
to make it easy for the neural network to process.
## The Model <a class="md-anchor" id="AUTOGENERATED-the-model"></a>
## The Model
### LSTM <a class="md-anchor" id="AUTOGENERATED-lstm"></a>
### LSTM
The core of the model consists of an LSTM cell that processes one word at the
time and computes probabilities of the possible continuations of the sentence.
@ -72,7 +72,7 @@ for current_batch_of_words in words_in_dataset:
loss += loss_function(probabilities, target_words)
```
### Truncated Backpropagation <a class="md-anchor" id="AUTOGENERATED-truncated-backpropagation"></a>
### Truncated Backpropagation
In order to make the learning process tractable, it is a common practice to
truncate the gradients for backpropagation to a fixed number (`num_steps`)
@ -114,7 +114,7 @@ for current_batch_of_words in words_in_dataset:
total_loss += current_loss
```
### Inputs <a class="md-anchor" id="AUTOGENERATED-inputs"></a>
### Inputs
The word IDs will be embedded into a dense representation (see the
[Vector Representations Tutorial](../../tutorials/word2vec/index.md)) before feeding to
@ -129,7 +129,7 @@ word_embeddings = tf.nn.embedding_lookup(embedding_matrix, word_ids)
The embedding matrix will be initialized randomly and the model will learn to
differentiate the meaning of words just by looking at the data.
### Loss Fuction <a class="md-anchor" id="AUTOGENERATED-loss-fuction"></a>
### Loss Fuction
We want to minimize the average negative log probability of the target words:
@ -145,7 +145,7 @@ $$e^{-\frac{1}{N}\sum_{i=1}^{N} \ln p_{\text{target}_i}} = e^{\text{loss}} $$
and we will monitor its value throughout the training process.
### Stacking multiple LSTMs <a class="md-anchor" id="AUTOGENERATED-stacking-multiple-lstms"></a>
### Stacking multiple LSTMs
To give the model more expressive power, we can add multiple layers of LSTMs
to process the data. The output of the first layer will become the input of
@ -168,7 +168,7 @@ for i in range(len(num_steps)):
final_state = state
```
## Compile and Run the Code <a class="md-anchor" id="AUTOGENERATED-compile-and-run-the-code"></a>
## Compile and Run the Code
First, the library needs to be built. To compile it on CPU:
@ -197,7 +197,7 @@ The larger the model, the better results it should get. The `small` model should
be able to reach perplexity below 120 on the test set and the `large` one below
80, though it might take several hours to train.
## What Next? <a class="md-anchor" id="AUTOGENERATED-what-next-"></a>
## What Next?
There are several tricks that we haven't mentioned that make the model better,
including:

View File

@ -1,4 +1,4 @@
# Sequence-to-Sequence Models <a class="md-anchor" id="AUTOGENERATED-sequence-to-sequence-models"></a>
# Sequence-to-Sequence Models
Recurrent neural networks can learn to model language, as already discussed
in the [RNN Tutorial](../../tutorials/recurrent/index.md)
@ -32,7 +32,7 @@ File | What's in it?
`translate/translate.py` | Binary that trains and runs the translation model.
## Sequence-to-Sequence Basics <a class="md-anchor" id="AUTOGENERATED-sequence-to-sequence-basics"></a>
## Sequence-to-Sequence Basics
A basic sequence-to-sequence model, as introduced in
[Cho et al., 2014](http://arxiv.org/pdf/1406.1078v3.pdf),
@ -64,7 +64,7 @@ attention mechanism in the decoder looks like this.
<img style="width:100%" src="attention_seq2seq.png" />
</div>
## TensorFlow seq2seq Library <a class="md-anchor" id="AUTOGENERATED-tensorflow-seq2seq-library"></a>
## TensorFlow seq2seq Library
As you can see above, there are many different sequence-to-sequence
models. Each of these models can use different RNN cells, but all
@ -141,14 +141,14 @@ more sequence-to-sequence models in `seq2seq.py`, take a look there. They all
have similar interfaces, so we will not describe them in detail. We will use
`embedding_attention_seq2seq` for our translation model below.
## Neural Translation Model <a class="md-anchor" id="AUTOGENERATED-neural-translation-model"></a>
## Neural Translation Model
While the core of the sequence-to-sequence model is constructed by
the functions in `models/rnn/seq2seq.py`, there are still a few tricks
that are worth mentioning that are used in our translation model in
`models/rnn/translate/seq2seq_model.py`.
### Sampled softmax and output projection <a class="md-anchor" id="AUTOGENERATED-sampled-softmax-and-output-projection"></a>
### Sampled softmax and output projection
For one, as already mentioned above, we want to use sampled softmax to
handle large output vocabulary. To decode from it, we need to keep track
@ -184,7 +184,7 @@ if output_projection is not None:
output_projection[1] for ...]
```
### Bucketing and padding <a class="md-anchor" id="AUTOGENERATED-bucketing-and-padding"></a>
### Bucketing and padding
In addition to sampled softmax, our translation model also makes use
of *bucketing*, which is a method to efficiently handle sentences of
@ -230,7 +230,7 @@ with encoder inputs representing `[PAD PAD "." "go" "I"]` and decoder
inputs `[GO "Je" "vais" "." EOS PAD PAD PAD PAD PAD]`.
## Let's Run It <a class="md-anchor" id="run_it"></a>
## Let's Run It {#run_it}
To train the model described above, we need to a large English-French corpus.
We will use the *10^9-French-English corpus* from the
@ -304,7 +304,7 @@ Reading model parameters from /tmp/translate.ckpt-340000
Qui est le président des États-Unis ?
```
## What Next? <a class="md-anchor" id="AUTOGENERATED-what-next-"></a>
## What Next?
The example above shows how you can build your own English-to-French
translator, end-to-end. Run it and see how the model performs for yourself.

View File

@ -1,11 +1,11 @@
# Vector Representations of Words <a class="md-anchor" id="AUTOGENERATED-vector-representations-of-words"></a>
# Vector Representations of Words
In this tutorial we look at the word2vec model by
[Mikolov et al.](http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf)
This model is used for learning vector representations of words, called "word
embeddings".
## Highlights <a class="md-anchor" id="AUTOGENERATED-highlights"></a>
## Highlights
This tutorial is meant to highlight the interesting, substantive parts of
building a word2vec model in TensorFlow.
@ -32,7 +32,7 @@ But first, let's look at why we would want to learn word embeddings in the first
place. Feel free to skip this section if you're an Embedding Pro and you'd just
like to get your hands dirty with the details.
## Motivation: Why Learn Word Embeddings? <a class="md-anchor" id="AUTOGENERATED-motivation--why-learn-word-embeddings-"></a>
## Motivation: Why Learn Word Embeddings?
Image and audio processing systems work with rich, high-dimensional datasets
encoded as vectors of the individual raw pixel-intensities for image data, or
@ -90,7 +90,7 @@ pair as a new observation, and this tends to do better when we have larger
datasets. We will focus on the skip-gram model in the rest of this tutorial.
## Scaling up with Noise-Contrastive Training <a class="md-anchor" id="AUTOGENERATED-scaling-up-with-noise-contrastive-training"></a>
## Scaling up with Noise-Contrastive Training
Neural probabilistic language models are traditionally trained using the
[maximum likelihood](https://en.wikipedia.org/wiki/Maximum_likelihood) (ML)
@ -166,7 +166,7 @@ loss, for which TensorFlow has a handy helper function `tf.nn.nce_loss()`.
Let's get an intuitive feel for how this would work in practice!
## The Skip-gram Model <a class="md-anchor" id="AUTOGENERATED-the-skip-gram-model"></a>
## The Skip-gram Model
As an example, let's consider the dataset
@ -243,7 +243,7 @@ NLP prediction tasks, such as part-of-speech tagging or named entity recognition
But for now, let's just use them to draw pretty pictures!
## Building the Graph <a class="md-anchor" id="AUTOGENERATED-building-the-graph"></a>
## Building the Graph
This is all about embeddings, so let's define our embedding matrix.
This is just a big random matrix to start. We'll initialize the values to be
@ -307,7 +307,7 @@ gradient descent, and TensorFlow has handy helpers to make this easy as well.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0).minimize(loss)
```
## Training the Model <a class="md-anchor" id="AUTOGENERATED-training-the-model"></a>
## Training the Model
Training the model is then as simple as using a `feed_dict` to push data into
the placeholders and calling
@ -323,7 +323,7 @@ for inputs, labels in generate_batch(...):
See the full example code in
[tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py](./word2vec_basic.py).
## Visualizing the Learned Embeddings <a class="md-anchor" id="AUTOGENERATED-visualizing-the-learned-embeddings"></a>
## Visualizing the Learned Embeddings
After training has finished we can visualize the learned embeddings using
t-SNE.
@ -337,7 +337,7 @@ other. For a more heavyweight implementation of word2vec that showcases more of
the advanced features of TensorFlow, see the implementation in
[tensorflow/models/embedding/word2vec.py](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/embedding/word2vec.py).
## Evaluating Embeddings: Analogical Reasoning <a class="md-anchor" id="AUTOGENERATED-evaluating-embeddings--analogical-reasoning"></a>
## Evaluating Embeddings: Analogical Reasoning
Embeddings are useful for a wide variety of prediction tasks in NLP. Short of
training a full-blown part-of-speech model or named-entity model, one simple way
@ -358,7 +358,7 @@ very large dataset, carefully tuning the hyperparameters and making use of
tricks like subsampling the data, which is out of the scope of this tutorial.
## Optimizing the Implementation <a class="md-anchor" id="AUTOGENERATED-optimizing-the-implementation"></a>
## Optimizing the Implementation
Our vanilla implementation showcases the flexibility of TensorFlow. For
example, changing the training objective is as simple as swapping out the call
@ -388,7 +388,7 @@ example of this for the Skip-Gram case
Feel free to benchmark these against each other to measure performance
improvements at each stage.
## Conclusion <a class="md-anchor" id="AUTOGENERATED-conclusion"></a>
## Conclusion
In this tutorial we covered the word2vec model, a computationally efficient
model for learning word embeddings. We motivated why embeddings are useful,

View File

@ -687,6 +687,29 @@ class ControlFlowTest(tf.test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
def _testWhileNested_1(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
n = tf.constant(0)
def cpu_sum(s):
c = lambda i, s: tf.less(i, 10)
def b(i, s):
i1 = tf.add(i, 1)
with tf.device("/cpu:0"):
s1 = tf.add(i, s)
return i1, s1
_, r_s = control_flow_ops.While(c, b, [n, s])
return r_s
c = lambda x: tf.less(x, 200)
b = lambda x: tf.add(x, cpu_sum(n))
r = control_flow_ops.While(c, b, [n])
result = r.eval()
self.assertEqual(225, result)
def testWhileNested_1(self):
self._testWhileNested_1(use_gpu=False)
self._testWhileNested_1(use_gpu=True)
def testWhileWithControl_1(self):
with self.test_session():
n = tf.constant(0)

View File

@ -35,6 +35,11 @@ class InTopKTest(tf.test.TestCase):
target = [2, 3]
self._validateInTopK(predictions, target, 2, [True, True])
def testInTop2_int64Target(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = np.asarray([0, 2]).astype(np.int64)
self._validateInTopK(predictions, target, 2, [False, True])
if __name__ == "__main__":
tf.test.main()

View File

@ -14,33 +14,76 @@ from tensorflow.python.kernel_tests import gradient_checker as gc
class MatrixInverseGradientTest(tf.test.TestCase):
pass # Filled in below
def _GetMatrixInverseGradientTest(dtype, shape):
def _GetMatrixInverseGradientTest(dtype_, shape_):
def Test(self):
with self.test_session():
np.random.seed(1)
m = np.random.uniform(low=1.0, high=100.0, size=np.prod(shape)).reshape(
shape).astype(dtype)
m = np.random.uniform(low=1.0,
high=100.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
a = tf.constant(m)
epsilon = np.finfo(dtype).eps
epsilon = np.finfo(dtype_).eps
# Optimal stepsize for central difference is O(epsilon^{1/3}).
delta = epsilon ** (1.0 / 3.0)
delta = epsilon**(1.0 / 3.0)
tol = 1e-3
if len(shape) == 2:
if len(shape_) == 2:
ainv = tf.matrix_inverse(a)
else:
ainv = tf.batch_matrix_inverse(a)
theoretical, numerical = gc.ComputeGradient(a, shape, ainv, shape,
theoretical, numerical = gc.ComputeGradient(a,
shape_,
ainv,
shape_,
delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
return Test
if __name__ == "__main__":
class MatrixDeterminantGradientTest(tf.test.TestCase):
pass # Filled in below
def _GetMatrixDeterminantGradientTest(dtype_, shape_):
def Test(self):
with self.test_session():
np.random.seed(1)
m = np.random.uniform(low=1.0,
high=100.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
a = tf.constant(m)
epsilon = np.finfo(dtype_).eps
# Optimal stepsize for central difference is O(epsilon^{1/3}).
delta = epsilon**(1.0 / 3.0)
# tolerance obtained by looking at actual differences using
# np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
tol = 1e-3
if len(shape_) == 2:
c = tf.matrix_determinant(a)
else:
c = tf.batch_matrix_determinant(a)
out_shape = shape_[:-2] # last two dimensions hold matrices
theoretical, numerical = gc.ComputeGradient(a, shape_, c, out_shape,
delta=delta)
self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
return Test
if __name__ == '__main__':
# TODO(rmlarsen,irving): Reenable float32 once tolerances are fixed
# The test used to loop over (np.float, np.double), both of which are float64.
for dtype in np.float64,:
for dtype in (np.float64,):
for size in 2, 3, 5, 10:
# We skip the rank 4, size 10 case: it is slow and conceptually covered
# by the other cases.
@ -49,4 +92,14 @@ if __name__ == "__main__":
name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
setattr(MatrixInverseGradientTest, 'testMatrixInverseGradient_' + name,
_GetMatrixInverseGradientTest(dtype, shape))
for dtype in (np.float64,):
for size in 2, 5, 10:
# increase this list to check batch version
for extra in [()]:
shape = extra+(size, size)
name = '%s_%s' % (dtype.__name__, '_'.join(map(str, shape)))
setattr(MatrixDeterminantGradientTest,
'testMatrixDeterminantGradient_' + name,
_GetMatrixDeterminantGradientTest(dtype, shape))
tf.test.main()

View File

@ -47,6 +47,7 @@ debug your graph.
@@Assert
@@Print
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -341,14 +342,6 @@ class ControlFlowOpWrapper(object):
"""
return self._op.device
@property
def output_types(self):
return self._op.output_types
@property
def input_types(self):
return self._op._input_types
@property
def type(self):
"""Returns the type of the op."""
@ -356,12 +349,12 @@ class ControlFlowOpWrapper(object):
@property
def graph(self):
"""Returns the parent graph."""
"""The `Graph` that contains this operation."""
return self._op.graph
def GetAttr(self, attr_name):
"""Returns the value of attribute 'attr_name' of NodeDef."""
return self._op.get_attr(attr_name)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`."""
return self._op.get_attr(name)
def _get_control_flow_context(self):
return self._op._get_control_flow_context()
@ -625,7 +618,7 @@ def cond(pred, fn1, fn2, name=None):
# r is set to f1()
```
"""
with ops.op_scope([pred], name, "Cond") as name:
with ops.op_scope([pred], name, "cond") as name:
if not callable(fn1):
raise TypeError("fn1 must be callable.")
if not callable(fn2):
@ -1316,7 +1309,7 @@ def fold(fn, elems, elem_shape, name=None):
Raises:
TypeError: if `fn` is not callable.
"""
with ops.op_scope([elems], name, "Fold") as name:
with ops.op_scope([elems], name, "fold") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
@ -1340,8 +1333,8 @@ def fold(fn, elems, elem_shape, name=None):
return r[1]
def case(pred_fn_pairs, default, exclusive=False, name="Case"):
"""Create a Case operation.
def case(pred_fn_pairs, default, exclusive=False, name="case"):
"""Create a case operation.
The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
Each pair contains a boolean scalar tensor and a python callable that
@ -1366,9 +1359,9 @@ def case(pred_fn_pairs, default, exclusive=False, name="Case"):
Expressions:
```
f1 = lambda: Constant(17)
f2 = lambda: Constant(23)
r = Case([(math_ops.less(x, y), f1)], default=f2)
f1 = lambda: tf.onstant(17)
f2 = lambda: tf.constant(23)
r = case([(tf.less(x, y), f1)], default=f2)
```
Example 2:
@ -1382,10 +1375,10 @@ def case(pred_fn_pairs, default, exclusive=False, name="Case"):
Expressions:
```
def f1(): return Constant(17)
def f2(): return Constant(23)
def f3(): return Constant(-1)
r = Case({math_ops.less(x, y): f1, math_ops.greater(x, z): f2},
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)
```
@ -1428,7 +1421,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="Case"):
raise TypeError("default must be callable.")
preds, fns = map(list, zip(*pfp))
with ops.op_scope([[f() for f in fns] + preds + [default()]], name, "Case"):
with ops.op_scope([[f() for f in fns] + preds + [default()]], name, "case"):
if not preds:
return default()
not_preds = []
@ -1451,10 +1444,10 @@ def case(pred_fn_pairs, default, exclusive=False, name="Case"):
with ops.name_scope("case_%d" % i):
case_preds.append(math_ops.logical_and(p, and_not_p_prev))
# case_sequence = [Cond(p3 & ..., f3, default),
# Cond(p2 & ..., f2, lambda: case_sequence[0]),
# case_sequence = [cond(p3 & ..., f3, default),
# cond(p2 & ..., f2, lambda: case_sequence[0]),
# ...
# Cond(p1 & True, f1, lambda: case_sequence[i-1])]
# cond(p1 & True, f1, lambda: case_sequence[i-1])]
# and prev_case_seq will loop from case_sequence[0] to case_sequence[-1]
if exclusive:
# TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds))

View File

@ -1,4 +1,10 @@
"""Gradients for operators defined in linalg_ops.py."""
"""Gradients for operators defined in linalg_ops.py.
Useful reference for derivative formulas is
An extended collection of matrix derivative results for forward and reverse
mode algorithmic differentiation by Mike Giles:
http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@ -10,20 +16,40 @@ from tensorflow.python.ops import constant_op
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
"""Gradient for MatrixInverse."""
ainv = op.outputs[0]
return -math_ops.matmul(
ainv,
math_ops.matmul(grad, ainv, transpose_b=True),
transpose_a=True)
return -math_ops.matmul(ainv,
math_ops.matmul(grad,
ainv,
transpose_b=True),
transpose_a=True)
@ops.RegisterGradient("BatchMatrixInverse")
def _BatchMatrixInverseGrad(op, grad):
"""Gradient for BatchMatrixInverse."""
ainv = op.outputs[0]
return -math_ops.batch_matmul(
ainv,
math_ops.batch_matmul(grad, ainv, adj_y=True),
adj_x=True)
return -math_ops.batch_matmul(ainv,
math_ops.batch_matmul(grad,
ainv,
adj_y=True),
adj_x=True)
@ops.RegisterGradient("MatrixDeterminant")
def _MatrixDeterminantGrad(op, grad):
"""Gradient for MatrixDeterminant.
Returns:
gradient
Args:
op: op
grad: grad
"""
a = op.inputs[0]
c = op.outputs[0]
ainv = linalg_ops.matrix_inverse(a)
return grad * c * array_ops.transpose(ainv)

View File

@ -28,8 +28,8 @@ def histogram_summary(tag, values, collections=None, name=None):
Args:
tag: A `string` `Tensor`. 0-D. Tag to use for the summary value.
values: A `float32` `Tensor`. Any shape. Values to use to build the
histogram.
values: A `float32` or `float64` `Tensor`. Any shape. Values to use to
build the histogram.
collections: Optional list of graph collections keys. The new summary op is
added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
name: A name for the operation (optional).

View File

@ -195,13 +195,21 @@ class ExponentialMovingAverage(object):
if var in self._averages:
raise ValueError("Moving average already computed for: %s" % var)
with ops.name_scope(var.op.name + "/" + self._name) as scope:
with ops.device(var.device):
if isinstance(var, variables.Variable):
initial_value = var.initialized_value()
else:
initial_value = array_ops.zeros(var.get_shape().as_list())
avg = variables.Variable(initial_value, name=scope, trainable=False)
self._averages[var] = avg
# For variables: to lower communication bandwidth across devices we keep
# the moving averages on the same device as the variables. For other
# tensors, we rely on the existing device allocation mechanism.
if isinstance(var, variables.Variable):
with ops.device(var.device):
avg = variables.Variable(var.initialized_value(),
name=scope, trainable=False)
elif var.op.type == "Variable":
with ops.device(var.device):
avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()),
name=scope, trainable=False)
else:
avg = variables.Variable(array_ops.zeros(var.get_shape().as_list()),
name=scope, trainable=False)
self._averages[var] = avg
with ops.name_scope(self._name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
if self._num_updates is not None:

View File

@ -6,9 +6,11 @@ from __future__ import print_function
import tensorflow.python.platform
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import types
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import moving_averages
@ -130,6 +132,19 @@ class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
def testAverageVariablesDeviceAssignment(self):
with ops.device("dev_v0"):
v0 = variables.Variable(10.0, name="v0")
with ops.device("dev_v1"):
v1 = state_ops.variable_op(shape=[1], dtype=types.float32, name="v1")
tensor2 = v0 + v1
ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
with ops.device("default"):
ema.apply([v0, v1, tensor2])
self.assertEqual("dev_v0", ema.average(v0).device)
self.assertEqual("dev_v1", ema.average(v1).device)
self.assertEqual("default", ema.average(tensor2).device)
if __name__ == "__main__":
googletest.main()

View File

@ -43,7 +43,7 @@
},
"devDependencies": {
"iron-component-page": "PolymerElements/iron-component-page#^1.0.0",
"web-component-tester": "*"
"web-component-tester": "Polymer/web-component-tester"
},
"resolutions": {
"d3": "3.5.6"

Some files were not shown because too many files have changed in this diff Show More