TensorFlow: Upstream a batch of changes to git.
Changes: - Some changes to make our ability to handle external contributions simpler (e.g., adding some markers, adding empty __init__.py files). - Fixing documentation of SummaryWriter.add_summary(). - Some input validation changes for queues/barriers. - Fixing the ptb tutorial (thanks to @kentonl for reporting), fixes, github issue 52. - Some documentation suggestions for dealing with a few install problems. - Speed improvements to conv2d gradient kernels on CPU. - More documentation fixes for the website. Thanks to @makky3939 and @keonkim for reports. - Changes / fixes to the Docker files. - Changes build_pip_package to not create an sdist. - Adds tensorboard example script. Base CL: 107445267
This commit is contained in:
parent
61d3a958d6
commit
9274f5aa47
tensorflow
BUILD
core
kernels
public
examples
g3doc
__init__.py
api_docs
cc
ClassRandomAccessFile.mdClassSession.mdClassStatus.mdClassTensor.mdClassTensorShape.mdClassTensorShapeUtils.md
index.mdget_started
how_tos
tutorials
models
opensource_only
python
client
framework
kernel_tests
lib
ops
platform/default
summary
training
user_ops
util
tensorboard
tools
@ -11,6 +11,8 @@ exports_files([
|
||||
"ACKNOWLEDGMENTS",
|
||||
])
|
||||
|
||||
# open source marker; do not remove
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = ["//tensorflow/..."],
|
||||
|
@ -385,21 +385,25 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
|
||||
auto* out_backprop_data = out_backprop.template flat<T>().data();
|
||||
auto* input_backprop_data = in_backprop->template flat<T>().data();
|
||||
|
||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor>> MatrixMap;
|
||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor>> ConstMatrixMap;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
|
||||
Eigen::Unaligned> TensorMap;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
|
||||
Eigen::Unaligned> ConstTensorMap;
|
||||
|
||||
// Initialize contraction dims (we need to transpose 'B' below).
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
|
||||
contract_dims[0].first = 1;
|
||||
contract_dims[0].second = 1;
|
||||
|
||||
for (int image_id = 0; image_id < batch; ++image_id) {
|
||||
// Compute gradient into col_buffer.
|
||||
MatrixMap C(col_buffer_data, output_image_size, filter_total_size);
|
||||
TensorMap C(col_buffer_data, output_image_size, filter_total_size);
|
||||
|
||||
ConstMatrixMap A(out_backprop_data + output_offset * image_id,
|
||||
ConstTensorMap A(out_backprop_data + output_offset * image_id,
|
||||
output_image_size, out_depth);
|
||||
ConstMatrixMap B(filter_data, filter_total_size, out_depth);
|
||||
ConstTensorMap B(filter_data, filter_total_size, out_depth);
|
||||
|
||||
// TODO(andydavis) Use a multi-threaded matmul implementation here.
|
||||
C.noalias() = A * B.transpose();
|
||||
C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
|
||||
|
||||
Col2im<T>(col_buffer_data, in_depth, input_rows, input_cols, filter_rows,
|
||||
filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
|
||||
@ -554,14 +558,19 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
|
||||
auto* out_backprop_data = out_backprop.template flat<T>().data();
|
||||
auto* filter_backprop_data = filter_backprop->template flat<T>().data();
|
||||
|
||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor>> MatrixMap;
|
||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
|
||||
Eigen::RowMajor>> ConstMatrixMap;
|
||||
|
||||
MatrixMap C(filter_backprop_data, filter_total_size, out_depth);
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
|
||||
Eigen::Unaligned> TensorMap;
|
||||
typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
|
||||
Eigen::Unaligned> ConstTensorMap;
|
||||
|
||||
TensorMap C(filter_backprop_data, filter_total_size, out_depth);
|
||||
C.setZero();
|
||||
|
||||
// Initialize contraction dims (we need to transpose 'A' below).
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
|
||||
contract_dims[0].first = 0;
|
||||
contract_dims[0].second = 0;
|
||||
|
||||
for (int image_id = 0; image_id < batch; ++image_id) {
|
||||
// When we compute the gradient with respect to the filters, we need to do
|
||||
// im2col to allow gemm-type computation.
|
||||
@ -569,13 +578,12 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
|
||||
filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
|
||||
stride, col_buffer_data);
|
||||
|
||||
ConstMatrixMap A(col_buffer_data, output_image_size, filter_total_size);
|
||||
ConstMatrixMap B(out_backprop_data + output_offset * image_id,
|
||||
ConstTensorMap A(col_buffer_data, output_image_size, filter_total_size);
|
||||
ConstTensorMap B(out_backprop_data + output_offset * image_id,
|
||||
output_image_size, out_depth);
|
||||
|
||||
// Compute gradient with respect to filter.
|
||||
// TODO(andydavis) Use a multi-threaded matmul implementation here.
|
||||
C.noalias() += A.transpose() * B;
|
||||
// Gradient with respect to filter.
|
||||
C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
|
||||
|
||||
input_data += input_offset;
|
||||
}
|
||||
|
@ -426,7 +426,8 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
|
||||
attempt->context->SetStatus(
|
||||
errors::DataLoss("Failed to restore element from "
|
||||
"partially-dequeued batch "
|
||||
"to FIFOQueue"));
|
||||
"to FIFOQueue: ",
|
||||
s.error_message()));
|
||||
}
|
||||
queues_[j].push_front(element);
|
||||
}
|
||||
|
@ -9,17 +9,39 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
template <DataType DT>
|
||||
void HandleSliceToElement(const Tensor& parent, Tensor* element, int index) {
|
||||
Status HandleSliceToElement(const Tensor& parent, Tensor* element, int index) {
|
||||
typedef typename EnumToDataType<DT>::Type T;
|
||||
DCHECK_NE(parent.dim_size(0), 0);
|
||||
if (element->NumElements() != (parent.NumElements() / parent.dim_size(0))) {
|
||||
TensorShape chip_shape = parent.shape();
|
||||
chip_shape.RemoveDim(0);
|
||||
return errors::Internal(
|
||||
"Cannot copy slice: number of elements does not match. Shapes are: "
|
||||
"[element]: ",
|
||||
element->shape().DebugString(), ", [parent slice]: ",
|
||||
chip_shape.DebugString());
|
||||
}
|
||||
auto parent_as_matrix = parent.flat_outer_dims<T>();
|
||||
element->flat<T>() = parent_as_matrix.chip(index, 0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <DataType DT>
|
||||
void HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
|
||||
Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
|
||||
typedef typename EnumToDataType<DT>::Type T;
|
||||
DCHECK_NE(parent->dim_size(0), 0);
|
||||
if (element.NumElements() != (parent->NumElements() / parent->dim_size(0))) {
|
||||
TensorShape chip_shape = parent->shape();
|
||||
chip_shape.RemoveDim(0);
|
||||
return errors::Internal(
|
||||
"Cannot copy slice: number of elements does not match. Shapes are: "
|
||||
"[element]: ",
|
||||
element.shape().DebugString(), ", [parent slice]: ",
|
||||
chip_shape.DebugString());
|
||||
}
|
||||
auto parent_as_matrix = parent->flat_outer_dims<T>();
|
||||
parent_as_matrix.chip(index, 0) = element.flat<T>();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -27,10 +49,10 @@ void HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
|
||||
// static
|
||||
Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
|
||||
int index) {
|
||||
#define HANDLE_TYPE(DT) \
|
||||
if (parent.dtype() == DT) { \
|
||||
HandleSliceToElement<DT>(parent, element, index); \
|
||||
return Status::OK(); \
|
||||
#define HANDLE_TYPE(DT) \
|
||||
if (parent.dtype() == DT) { \
|
||||
TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
HANDLE_TYPE(DT_FLOAT);
|
||||
HANDLE_TYPE(DT_DOUBLE);
|
||||
@ -47,10 +69,10 @@ Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
|
||||
// static
|
||||
Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
|
||||
int index) {
|
||||
#define HANDLE_TYPE(DT) \
|
||||
if (element.dtype() == DT) { \
|
||||
HandleElementToSlice<DT>(element, parent, index); \
|
||||
return Status::OK(); \
|
||||
#define HANDLE_TYPE(DT) \
|
||||
if (element.dtype() == DT) { \
|
||||
TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
|
||||
return Status::OK(); \
|
||||
}
|
||||
HANDLE_TYPE(DT_FLOAT);
|
||||
HANDLE_TYPE(DT_DOUBLE);
|
||||
|
@ -25,7 +25,7 @@ class StringToHashBucketOp : public OpKernel {
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<int64>();
|
||||
|
||||
for (int i = 0; i < input_flat.size(); ++i) {
|
||||
for (std::size_t i = 0; i < input_flat.size(); ++i) {
|
||||
const uint64 input_hash = Hash64(input_flat(i));
|
||||
const uint64 bucket_id = input_hash % num_buckets_;
|
||||
// The number of buckets is always in the positive range of int64 so is
|
||||
|
@ -33,7 +33,7 @@ class StringToNumberOp : public OpKernel {
|
||||
&output_tensor));
|
||||
auto output_flat = output_tensor->flat<OutputType>();
|
||||
|
||||
for (int i = 0; i < input_flat.size(); ++i) {
|
||||
for (std::size_t i = 0; i < input_flat.size(); ++i) {
|
||||
const char* s = input_flat(i).data();
|
||||
Convert(s, &output_flat(i), context);
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ class RandomAccessFile {
|
||||
/// On OK returned status: "n" bytes have been stored in "*result".
|
||||
/// On non-OK returned status: [0..n] bytes have been stored in "*result".
|
||||
///
|
||||
/// Returns OUT_OF_RANGE if fewer than n bytes were stored in "*result"
|
||||
/// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in "*result"
|
||||
/// because of EOF.
|
||||
///
|
||||
/// Safe for concurrent use by multiple threads.
|
||||
@ -155,7 +155,7 @@ class RandomAccessFile {
|
||||
void operator=(const RandomAccessFile&);
|
||||
};
|
||||
|
||||
/// \brief A file abstraction for sequential writing.
|
||||
/// \brief A file abstraction for sequential writing.
|
||||
///
|
||||
/// The implementation must provide buffering since callers may append
|
||||
/// small fragments at a time to the file.
|
||||
|
@ -80,21 +80,21 @@ class Session {
|
||||
virtual Status Extend(const GraphDef& graph) = 0;
|
||||
|
||||
/// \brief Runs the graph with the provided input tensors and fills
|
||||
/// 'outputs' for the endpoints specified in 'output_tensor_names'.
|
||||
/// `outputs` for the endpoints specified in `output_tensor_names`.
|
||||
/// Runs to but does not return Tensors for the nodes in
|
||||
/// 'target_node_names'.
|
||||
/// `target_node_names`.
|
||||
///
|
||||
/// The order of tensors in 'outputs' will match the order provided
|
||||
/// by 'output_tensor_names'.
|
||||
/// The order of tensors in `outputs` will match the order provided
|
||||
/// by `output_tensor_names`.
|
||||
///
|
||||
/// If Run returns OK(), then outputs->size() will be equal to
|
||||
/// output_tensor_names.size(). If Run does not return OK(), the
|
||||
/// state of outputs is undefined.
|
||||
/// If `Run` returns `OK()`, then `outputs->size()` will be equal to
|
||||
/// `output_tensor_names.size()`. If `Run` does not return `OK()`, the
|
||||
/// state of `outputs` is undefined.
|
||||
///
|
||||
/// REQUIRES: The name of each Tensor of the input or output must
|
||||
/// match a "Tensor endpoint" in the GraphDef passed to Create().
|
||||
/// match a "Tensor endpoint" in the `GraphDef` passed to `Create()`.
|
||||
///
|
||||
/// REQUIRES: outputs is not nullptr if output_tensor_names is non-empty.
|
||||
/// REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty.
|
||||
virtual Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
|
||||
const std::vector<string>& output_tensor_names,
|
||||
const std::vector<string>& target_node_names,
|
||||
@ -104,7 +104,7 @@ class Session {
|
||||
///
|
||||
/// Closing a session releases the resources used by this session
|
||||
/// on the TensorFlow runtime (specified during session creation by
|
||||
/// the 'SessionOptions::target' field).
|
||||
/// the `SessionOptions::target` field).
|
||||
virtual Status Close() = 0;
|
||||
|
||||
virtual ~Session() {}
|
||||
@ -112,15 +112,15 @@ class Session {
|
||||
|
||||
/// \brief Create a new session with the given options.
|
||||
///
|
||||
/// If a new session object could not be created, this function will
|
||||
/// If a new `Session` object could not be created, this function will
|
||||
/// return nullptr.
|
||||
Session* NewSession(const SessionOptions& options);
|
||||
|
||||
/// \brief Create a new session with the given options.
|
||||
///
|
||||
/// If session creation succeeds, the new Session will be stored in
|
||||
/// *out_session, the caller will take ownership of the returned
|
||||
/// *out_session, and this function will return OK(). Otherwise, this
|
||||
/// If session creation succeeds, the new `Session` will be stored in
|
||||
/// `*out_session`, the caller will take ownership of the returned
|
||||
/// `*out_session`, and this function will return `OK()`. Otherwise, this
|
||||
/// function will return an error status.
|
||||
Status NewSession(const SessionOptions& options, Session** out_session);
|
||||
|
||||
|
@ -39,19 +39,19 @@ class Status {
|
||||
bool operator==(const Status& x) const;
|
||||
bool operator!=(const Status& x) const;
|
||||
|
||||
/// \brief If "ok()", stores "new_status" into *this. If "!ok()", preserves
|
||||
/// the current status, but may augment with additional information
|
||||
/// about "new_status".
|
||||
/// \brief If `ok()`, stores `new_status` into `*this`. If `!ok()`,
|
||||
/// preserves the current status, but may augment with additional
|
||||
/// information about `new_status`.
|
||||
///
|
||||
/// Convenient way of keeping track of the first error encountered.
|
||||
/// Instead of:
|
||||
/// if (overall_status.ok()) overall_status = new_status
|
||||
/// `if (overall_status.ok()) overall_status = new_status`
|
||||
/// Use:
|
||||
/// overall_status.Update(new_status);
|
||||
/// `overall_status.Update(new_status);`
|
||||
void Update(const Status& new_status);
|
||||
|
||||
/// \brief Return a string representation of this status suitable for
|
||||
/// printing. Returns the string "OK" for success.
|
||||
/// printing. Returns the string `"OK"` for success.
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
@ -60,8 +60,8 @@ class Status {
|
||||
tensorflow::error::Code code;
|
||||
string msg;
|
||||
};
|
||||
/// OK status has a NULL state_. Otherwise, state_ points to
|
||||
/// a State structure containing the error code and message(s)
|
||||
// OK status has a `NULL` state_. Otherwise, `state_` points to
|
||||
// a `State` structure containing the error code and message(s)
|
||||
State* state_;
|
||||
|
||||
void SlowCopyFrom(const State* src);
|
||||
@ -71,8 +71,8 @@ inline Status::Status(const Status& s)
|
||||
: state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {}
|
||||
|
||||
inline void Status::operator=(const Status& s) {
|
||||
/// The following condition catches both aliasing (when this == &s),
|
||||
/// and the common case where both s and *this are ok.
|
||||
// The following condition catches both aliasing (when this == &s),
|
||||
// and the common case where both s and *this are ok.
|
||||
if (state_ != s.state_) {
|
||||
SlowCopyFrom(s.state_);
|
||||
}
|
||||
|
@ -27,15 +27,15 @@ class Tensor {
|
||||
/// Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
|
||||
Tensor();
|
||||
|
||||
/// \brief Creates a Tensor of the given datatype and shape.
|
||||
/// \brief Creates a Tensor of the given `type` and `shape`.
|
||||
///
|
||||
/// The underlying buffer is allocated using a CPUAllocator.
|
||||
/// The underlying buffer is allocated using a `CPUAllocator`.
|
||||
Tensor(DataType type, const TensorShape& shape);
|
||||
|
||||
/// \brief Creates a tensor with the input datatype and shape, using the
|
||||
/// allocator 'a' to allocate the underlying buffer.
|
||||
/// \brief Creates a tensor with the input `type` and `shape`, using the
|
||||
/// allocator `a` to allocate the underlying buffer.
|
||||
///
|
||||
/// 'a' must outlive the lifetime of this Tensor.
|
||||
/// `a` must outlive the lifetime of this Tensor.
|
||||
Tensor(Allocator* a, DataType type, const TensorShape& shape);
|
||||
|
||||
/// Creates an uninitialized Tensor of the given data type.
|
||||
@ -54,7 +54,7 @@ class Tensor {
|
||||
/// \brief Convenience accessor for the tensor shape.
|
||||
///
|
||||
/// For all shape accessors, see comments for relevant methods of
|
||||
/// TensorShape in tensor_shape.h.
|
||||
/// `TensorShape` in `tensor_shape.h`.
|
||||
int dims() const { return shape().dims(); }
|
||||
|
||||
/// Convenience accessor for the tensor shape.
|
||||
@ -81,9 +81,9 @@ class Tensor {
|
||||
|
||||
/// \brief Copy the other tensor into this tensor and reshape it.
|
||||
///
|
||||
/// This tensor shares other's underlying storage. Returns
|
||||
/// true iff other.shape() has the same number of elements of the
|
||||
/// given "shape".
|
||||
/// This tensor shares other's underlying storage. Returns `true`
|
||||
/// iff `other.shape()` has the same number of elements of the given
|
||||
/// `shape`.
|
||||
bool CopyFrom(const Tensor& other,
|
||||
const TensorShape& shape) TF_MUST_USE_RESULT {
|
||||
if (other.NumElements() != shape.num_elements()) return false;
|
||||
@ -93,42 +93,41 @@ class Tensor {
|
||||
|
||||
/// \brief Slice this tensor along the 1st dimension.
|
||||
|
||||
/// I.e., the returned
|
||||
/// tensor satisifies returned[i, ...] == this[dim0_start + i, ...].
|
||||
/// I.e., the returned tensor satisifies
|
||||
/// returned[i, ...] == this[dim0_start + i, ...].
|
||||
/// The returned tensor shares the underlying tensor buffer with this
|
||||
/// tensor.
|
||||
///
|
||||
/// NOTE: The returned tensor may not satisfies the same alignment
|
||||
/// requirement as this tensor depending on the shape. The caller
|
||||
/// must check the returned tensor's alignment before calling certain
|
||||
/// methods that have alignment requirement (e.g., flat(), tensor()).
|
||||
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
|
||||
///
|
||||
/// REQUIRES: dims() >= 1
|
||||
/// REQUIRES: 0 <= dim0_start <= dim0_limit <= dim_size(0)
|
||||
/// REQUIRES: `dims()` >= 1
|
||||
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
|
||||
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
|
||||
|
||||
/// \brief Parse "other' and construct the tensor.
|
||||
/// \brief Parse `other` and construct the tensor.
|
||||
|
||||
/// Returns true iff the
|
||||
/// parsing succeeds. If the parsing fails, the state of "*this" is
|
||||
/// unchanged.
|
||||
/// Returns `true` iff the parsing succeeds. If the parsing fails,
|
||||
/// the state of `*this` is unchanged.
|
||||
bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
|
||||
bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;
|
||||
|
||||
/// \brief Fills in "proto" with "*this" tensor's content.
|
||||
/// \brief Fills in `proto` with `*this` tensor's content.
|
||||
///
|
||||
/// AsProtoField() fills in the repeated field for proto.dtype(), while
|
||||
/// AsProtoTensorContent() encodes the content in proto.tensor_content() in a
|
||||
/// compact form.
|
||||
/// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
|
||||
/// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
|
||||
/// in a compact form.
|
||||
void AsProtoField(TensorProto* proto) const;
|
||||
void AsProtoTensorContent(TensorProto* proto) const;
|
||||
|
||||
/// \brief Return the Tensor data as an Eigen::Tensor with the type and
|
||||
/// sizes of this Tensor.
|
||||
/// \brief Return the tensor data as an `Eigen::Tensor` with the type and
|
||||
/// sizes of this `Tensor`.
|
||||
///
|
||||
/// Use these methods when you know the data type and the number of
|
||||
/// dimensions of the Tensor and you want an Eigen::Tensor
|
||||
/// automatically sized to the Tensor sizes. The implementation check
|
||||
/// dimensions of the Tensor and you want an `Eigen::Tensor`
|
||||
/// automatically sized to the `Tensor` sizes. The implementation check
|
||||
/// fails if either type or sizes mismatch.
|
||||
///
|
||||
/// Example:
|
||||
@ -157,14 +156,14 @@ class Tensor {
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor tensor();
|
||||
|
||||
/// \brief Return the Tensor data as an Eigen::Tensor of the data type and a
|
||||
/// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a
|
||||
/// specified shape.
|
||||
///
|
||||
/// These methods allow you to access the data with the dimensions
|
||||
/// and sizes of your choice. You do not need to know the number of
|
||||
/// dimensions of the Tensor to call them. However, they CHECK that
|
||||
/// dimensions of the Tensor to call them. However, they `CHECK` that
|
||||
/// the type matches and the dimensions requested creates an
|
||||
/// Eigen::Tensor with the same number of elements as the Tensor.
|
||||
/// `Eigen::Tensor` with the same number of elements as the tensor.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
@ -231,11 +230,11 @@ class Tensor {
|
||||
typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes);
|
||||
|
||||
/// \brief Return the Tensor data as a Tensor Map of fixed size 1:
|
||||
/// TensorMap<TensorFixedSize<T, 1>>.
|
||||
/// \brief Return the Tensor data as a `TensorMap` of fixed size 1:
|
||||
/// `TensorMap<TensorFixedSize<T, 1>>`.
|
||||
|
||||
/// Using scalar() allows the compiler to
|
||||
/// perform optimizations as the size of the tensor is known at compile time.
|
||||
/// Using `scalar()` allows the compiler to perform optimizations as
|
||||
/// the size of the tensor is known at compile time.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Scalar scalar();
|
||||
|
||||
@ -297,27 +296,27 @@ class Tensor {
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstScalar scalar() const;
|
||||
|
||||
/// Render the first max_entries values in *this into a string.
|
||||
/// Render the first `max_entries` values in `*this` into a string.
|
||||
string SummarizeValue(int64 max_entries) const;
|
||||
|
||||
/// A human-readable summary of the Tensor suitable for debugging.
|
||||
/// A human-readable summary of the tensor suitable for debugging.
|
||||
string DebugString() const;
|
||||
|
||||
/// Fill in the TensorDescription proto with metadata about the
|
||||
/// Tensor that is useful for monitoring and debugging.
|
||||
/// Fill in the `TensorDescription` proto with metadata about the
|
||||
/// tensor that is useful for monitoring and debugging.
|
||||
void FillDescription(TensorDescription* description) const;
|
||||
|
||||
/// \brief Returns a StringPiece mapping the current tensor's buffer.
|
||||
/// \brief Returns a `StringPiece` mapping the current tensor's buffer.
|
||||
///
|
||||
/// The returned StringPiece may point to memory location on devices
|
||||
/// The returned `StringPiece` may point to memory location on devices
|
||||
/// that the CPU cannot address directly.
|
||||
///
|
||||
/// NOTE: The underlying Tensor buffer is refcounted, so the lifetime
|
||||
/// of the contents mapped by the StringPiece matches the lifetime of
|
||||
/// NOTE: The underlying tensor buffer is refcounted, so the lifetime
|
||||
/// of the contents mapped by the `StringPiece` matches the lifetime of
|
||||
/// the buffer; callers should arrange to make sure the buffer does
|
||||
/// not get destroyed while the StringPiece is still used.
|
||||
/// not get destroyed while the `StringPiece` is still used.
|
||||
///
|
||||
/// REQUIRES: DataTypeCanUseMemcpy(dtype()).
|
||||
/// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
|
||||
StringPiece tensor_data() const;
|
||||
|
||||
private:
|
||||
|
@ -18,52 +18,53 @@ class TensorShapeIter; // Declared below
|
||||
/// Manages the dimensions of a Tensor and their sizes.
|
||||
class TensorShape {
|
||||
public:
|
||||
/// \brief Construct a TensorShape from the provided sizes..
|
||||
/// REQUIRES: dim_sizes[i] >= 0
|
||||
/// \brief Construct a `TensorShape` from the provided sizes.
|
||||
/// REQUIRES: `dim_sizes[i] >= 0`
|
||||
explicit TensorShape(gtl::ArraySlice<int64> dim_sizes);
|
||||
TensorShape(std::initializer_list<int64> dim_sizes)
|
||||
: TensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
|
||||
|
||||
/// REQUIRES: IsValid(proto)
|
||||
/// REQUIRES: `IsValid(proto)`
|
||||
explicit TensorShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Create a tensor shape with no dimensions and one element, which you can
|
||||
/// then call AddDim() on.
|
||||
/// then call `AddDim()` on.
|
||||
TensorShape();
|
||||
|
||||
/// Returns true iff "proto" is a valid tensor shape.
|
||||
/// Returns `true` iff `proto` is a valid tensor shape.
|
||||
static bool IsValid(const TensorShapeProto& proto);
|
||||
|
||||
/// Clear a tensor shape
|
||||
void Clear();
|
||||
|
||||
/// \brief Add a dimension to the end ("inner-most").
|
||||
/// REQUIRES: size >= 0
|
||||
/// REQUIRES: `size >= 0`
|
||||
void AddDim(int64 size);
|
||||
|
||||
/// Appends all the dimensions from shape.
|
||||
/// Appends all the dimensions from `shape`.
|
||||
void AppendShape(const TensorShape& shape);
|
||||
|
||||
/// \brief Insert a dimension somewhere in the TensorShape.
|
||||
/// REQUIRES: "0 <= d <= dims()"
|
||||
/// REQUIRES: size >= 0
|
||||
/// \brief Insert a dimension somewhere in the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d <= dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void InsertDim(int d, int64 size);
|
||||
|
||||
/// \brief Modifies the size of the dimension 'd' to be 'size'
|
||||
/// REQUIRES: "0 <= d < dims()"
|
||||
/// REQUIRES: size >= 0
|
||||
/// \brief Modifies the size of the dimension `d` to be `size`
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void set_dim(int d, int64 size);
|
||||
|
||||
/// \brief Removes dimension 'd' from the TensorShape.
|
||||
/// REQUIRES: "0 <= d < dims()"
|
||||
/// \brief Removes dimension `d` from the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
void RemoveDim(int d);
|
||||
|
||||
/// Return the number of dimensions in the tensor.
|
||||
int dims() const { return dim_sizes_.size(); }
|
||||
|
||||
/// \brief Returns the number of elements in dimension "d".
|
||||
/// REQUIRES: "0 <= d < dims()"
|
||||
// TODO(touts): Rename to dimension() to match Eigen::Tensor::dimension()?
|
||||
/// \brief Returns the number of elements in dimension `d`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
// TODO(touts): Rename to `dimension()` to match
|
||||
// `Eigen::Tensor::dimension()`?
|
||||
int64 dim_size(int d) const {
|
||||
DCHECK_GE(d, 0);
|
||||
DCHECK_LT(d, dims());
|
||||
@ -75,23 +76,24 @@ class TensorShape {
|
||||
|
||||
/// \brief Returns the number of elements in the tensor.
|
||||
///
|
||||
/// We use int64 and
|
||||
/// not size_t to be compatible with Eigen::Tensor which uses ptr_fi
|
||||
/// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
|
||||
/// which uses `ptrdiff_t`.
|
||||
int64 num_elements() const { return num_elements_; }
|
||||
|
||||
/// Returns true if *this and b have the same sizes. Ignores dimension names.
|
||||
/// Returns true if `*this` and `b` have the same sizes. Ignores
|
||||
/// dimension names.
|
||||
bool IsSameSize(const TensorShape& b) const;
|
||||
bool operator==(const TensorShape& b) const { return IsSameSize(b); }
|
||||
|
||||
/// Fill *proto from *this.
|
||||
/// Fill `*proto` from `*this`.
|
||||
void AsProto(TensorShapeProto* proto) const;
|
||||
|
||||
/// Fill *dsizes from *this.
|
||||
/// Fill `*dsizes` from `*this`.
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;
|
||||
|
||||
/// Same as AsEigenDSizes() but allows for NDIMS > dims() -- in which case we
|
||||
/// pad the rest of the sizes with 1.
|
||||
/// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
|
||||
/// which case we pad the rest of the sizes with 1.
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;
|
||||
|
||||
@ -105,14 +107,14 @@ class TensorShape {
|
||||
string ShortDebugString() const;
|
||||
|
||||
private:
|
||||
/// Recalculates the dimensions of this tensor after they are modified.
|
||||
// Recalculates the dimensions of this tensor after they are modified.
|
||||
void recompute_dims();
|
||||
|
||||
// TODO(josh11b): Maybe use something from the Eigen Tensor library
|
||||
/// for the sizes.
|
||||
// for the sizes.
|
||||
gtl::InlinedVector<int64, 4> dim_sizes_;
|
||||
|
||||
/// total number of elements (avoids recomputing it each time).
|
||||
// total number of elements (avoids recomputing it each time).
|
||||
int64 num_elements_;
|
||||
};
|
||||
|
||||
@ -151,7 +153,7 @@ static const bool kAllowLegacyScalars = true;
|
||||
static const bool kAllowLegacyScalars = false;
|
||||
#endif
|
||||
|
||||
/// \brief Static helper routines for TensorShape. Includes a few common
|
||||
/// \brief Static helper routines for `TensorShape`. Includes a few common
|
||||
/// predicates on a tensor shape.
|
||||
class TensorShapeUtils {
|
||||
public:
|
||||
@ -180,8 +182,8 @@ class TensorShapeUtils {
|
||||
return shape.dims() >= 2;
|
||||
}
|
||||
|
||||
/// \brief Returns a TensorShape whose dimensions are dims[0], dims[1], ...,
|
||||
/// dims[n-1].
|
||||
/// \brief Returns a `TensorShape` whose dimensions are
|
||||
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
|
||||
template <typename T>
|
||||
static TensorShape MakeShape(const T* dims, int n) {
|
||||
TensorShape shape;
|
||||
@ -203,11 +205,6 @@ class TensorShapeUtils {
|
||||
static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1);
|
||||
};
|
||||
|
||||
// TODO(josh11b): Add TensorStrides once we support strides
|
||||
// struct TensorStrides {
|
||||
// gtl::InlinedVector<int, 4> strides_;
|
||||
// };
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Template method implementation details below
|
||||
// ----------------------------------------------------------------------------
|
||||
|
0
tensorflow/examples/__init__.py
Normal file
0
tensorflow/examples/__init__.py
Normal file
0
tensorflow/examples/android/__init__.py
Executable file → Normal file
0
tensorflow/examples/android/__init__.py
Executable file → Normal file
0
tensorflow/examples/android/jni/__init__.py
Executable file → Normal file
0
tensorflow/examples/android/jni/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/__init__.py
Executable file → Normal file
@ -33,6 +33,6 @@ Reads up to "n" bytes from the file starting at "offset".
|
||||
|
||||
On OK returned status: "n" bytes have been stored in "*result". On non-OK returned status: [0..n] bytes have been stored in "*result".
|
||||
|
||||
Returns OUT_OF_RANGE if fewer than n bytes were stored in "*result" because of EOF.
|
||||
Returns `OUT_OF_RANGE` if fewer than n bytes were stored in "*result" because of EOF.
|
||||
|
||||
Safe for concurrent use by multiple threads.
|
||||
|
@ -48,7 +48,7 @@ Only one thread must call Close() , and Close() must only be called after all ot
|
||||
* [`virtual Status tensorflow::Session::Extend(const GraphDef &graph)=0`](#virtual_Status_tensorflow_Session_Extend)
|
||||
* Adds operations to the graph that is already registered with the Session .
|
||||
* [`virtual Status tensorflow::Session::Run(const std::vector< std::pair< string, Tensor > > &inputs, const std::vector< string > &output_tensor_names, const std::vector< string > &target_node_names, std::vector< Tensor > *outputs)=0`](#virtual_Status_tensorflow_Session_Run)
|
||||
* Runs the graph with the provided input tensors and fills 'outputs' for the endpoints specified in 'output_tensor_names'. Runs to but does not return Tensors for the nodes in 'target_node_names'.
|
||||
* Runs the graph with the provided input tensors and fills `outputs` for the endpoints specified in `output_tensor_names`. Runs to but does not return Tensors for the nodes in `target_node_names`.
|
||||
* [`virtual Status tensorflow::Session::Close()=0`](#virtual_Status_tensorflow_Session_Close)
|
||||
* Closes this session.
|
||||
* [`virtual tensorflow::Session::~Session()`](#virtual_tensorflow_Session_Session)
|
||||
@ -69,21 +69,21 @@ The names of new operations in "graph" must not exist in the graph that is alrea
|
||||
|
||||
#### `virtual Status tensorflow::Session::Run(const std::vector< std::pair< string, Tensor > > &inputs, const std::vector< string > &output_tensor_names, const std::vector< string > &target_node_names, std::vector< Tensor > *outputs)=0` <a class="md-anchor" id="virtual_Status_tensorflow_Session_Run"></a>
|
||||
|
||||
Runs the graph with the provided input tensors and fills 'outputs' for the endpoints specified in 'output_tensor_names'. Runs to but does not return Tensors for the nodes in 'target_node_names'.
|
||||
Runs the graph with the provided input tensors and fills `outputs` for the endpoints specified in `output_tensor_names`. Runs to but does not return Tensors for the nodes in `target_node_names`.
|
||||
|
||||
The order of tensors in 'outputs' will match the order provided by 'output_tensor_names'.
|
||||
The order of tensors in `outputs` will match the order provided by `output_tensor_names`.
|
||||
|
||||
If Run returns OK(), then outputs->size() will be equal to output_tensor_names.size(). If Run does not return OK(), the state of outputs is undefined.
|
||||
If `Run` returns `OK()`, then `outputs->size()` will be equal to `output_tensor_names.size()`. If `Run` does not return `OK()`, the state of `outputs` is undefined.
|
||||
|
||||
REQUIRES: The name of each Tensor of the input or output must match a "Tensor endpoint" in the GraphDef passed to Create() .
|
||||
REQUIRES: The name of each Tensor of the input or output must match a "Tensor endpoint" in the `GraphDef` passed to ` Create() `.
|
||||
|
||||
REQUIRES: outputs is not nullptr if output_tensor_names is non-empty.
|
||||
REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty.
|
||||
|
||||
#### `virtual Status tensorflow::Session::Close()=0` <a class="md-anchor" id="virtual_Status_tensorflow_Session_Close"></a>
|
||||
|
||||
Closes this session.
|
||||
|
||||
Closing a session releases the resources used by this session on the TensorFlow runtime (specified during session creation by the ' SessionOptions::target ' field).
|
||||
Closing a session releases the resources used by this session on the TensorFlow runtime (specified during session creation by the ` SessionOptions::target ` field).
|
||||
|
||||
#### `virtual tensorflow::Session::~Session()` <a class="md-anchor" id="virtual_tensorflow_Session_Session"></a>
|
||||
|
||||
|
@ -21,9 +21,9 @@
|
||||
* [`bool tensorflow::Status::operator==(const Status &x) const`](#bool_tensorflow_Status_operator_)
|
||||
* [`bool tensorflow::Status::operator!=(const Status &x) const`](#bool_tensorflow_Status_operator_)
|
||||
* [`void tensorflow::Status::Update(const Status &new_status)`](#void_tensorflow_Status_Update)
|
||||
* If "ok()", stores "new_status" into *this. If "!ok()", preserves the current status, but may augment with additional information about "new_status".
|
||||
* If ` ok() `, stores `new_status` into `*this`. If `!ok()`, preserves the current status, but may augment with additional information about `new_status`.
|
||||
* [`string tensorflow::Status::ToString() const`](#string_tensorflow_Status_ToString)
|
||||
* Return a string representation of this status suitable for printing. Returns the string "OK" for success.
|
||||
* Return a string representation of this status suitable for printing. Returns the string `"OK"` for success.
|
||||
* [`static Status tensorflow::Status::OK()`](#static_Status_tensorflow_Status_OK)
|
||||
|
||||
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
|
||||
@ -90,13 +90,13 @@ Returns true iff the status indicates success.
|
||||
|
||||
#### `void tensorflow::Status::Update(const Status &new_status)` <a class="md-anchor" id="void_tensorflow_Status_Update"></a>
|
||||
|
||||
If "ok()", stores "new_status" into *this. If "!ok()", preserves the current status, but may augment with additional information about "new_status".
|
||||
If ` ok() `, stores `new_status` into `*this`. If `!ok()`, preserves the current status, but may augment with additional information about `new_status`.
|
||||
|
||||
Convenient way of keeping track of the first error encountered. Instead of: if (overall_status.ok()) overall_status = new_status Use: overall_status.Update(new_status);
|
||||
Convenient way of keeping track of the first error encountered. Instead of: `if (overall_status.ok()) overall_status = new_status` Use: `overall_status.Update(new_status);`
|
||||
|
||||
#### `string tensorflow::Status::ToString() const` <a class="md-anchor" id="string_tensorflow_Status_ToString"></a>
|
||||
|
||||
Return a string representation of this status suitable for printing. Returns the string "OK" for success.
|
||||
Return a string representation of this status suitable for printing. Returns the string `"OK"` for success.
|
||||
|
||||
|
||||
|
||||
|
@ -9,9 +9,9 @@ Represents an n-dimensional array of values.
|
||||
* [`tensorflow::Tensor::Tensor()`](#tensorflow_Tensor_Tensor)
|
||||
* Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
|
||||
* [`tensorflow::Tensor::Tensor(DataType type, const TensorShape &shape)`](#tensorflow_Tensor_Tensor)
|
||||
* Creates a Tensor of the given datatype and shape.
|
||||
* Creates a Tensor of the given `type` and `shape`.
|
||||
* [`tensorflow::Tensor::Tensor(Allocator *a, DataType type, const TensorShape &shape)`](#tensorflow_Tensor_Tensor)
|
||||
* Creates a tensor with the input datatype and shape, using the allocator 'a' to allocate the underlying buffer.
|
||||
* Creates a tensor with the input `type` and `shape`, using the allocator `a` to allocate the underlying buffer.
|
||||
* [`tensorflow::Tensor::Tensor(DataType type)`](#tensorflow_Tensor_Tensor)
|
||||
* Creates an uninitialized Tensor of the given data type.
|
||||
* [`tensorflow::Tensor::Tensor(const Tensor &other)`](#tensorflow_Tensor_Tensor)
|
||||
@ -39,24 +39,24 @@ Represents an n-dimensional array of values.
|
||||
* [`Tensor tensorflow::Tensor::Slice(int64 dim0_start, int64 dim0_limit) const`](#Tensor_tensorflow_Tensor_Slice)
|
||||
* Slice this tensor along the 1st dimension.
|
||||
* [`bool tensorflow::Tensor::FromProto(const TensorProto &other) TF_MUST_USE_RESULT`](#bool_tensorflow_Tensor_FromProto)
|
||||
* Parse "other' and construct the tensor.
|
||||
* Parse `other` and construct the tensor.
|
||||
* [`bool tensorflow::Tensor::FromProto(Allocator *a, const TensorProto &other) TF_MUST_USE_RESULT`](#bool_tensorflow_Tensor_FromProto)
|
||||
* [`void tensorflow::Tensor::AsProtoField(TensorProto *proto) const`](#void_tensorflow_Tensor_AsProtoField)
|
||||
* Fills in "proto" with "*this" tensor's content.
|
||||
* Fills in `proto` with `*this` tensor's content.
|
||||
* [`void tensorflow::Tensor::AsProtoTensorContent(TensorProto *proto) const`](#void_tensorflow_Tensor_AsProtoTensorContent)
|
||||
* [`TTypes<T>::Vec tensorflow::Tensor::vec()`](#TTypes_T_Vec_tensorflow_Tensor_vec)
|
||||
* Return the Tensor data as an Eigen::Tensor with the type and sizes of this Tensor .
|
||||
* Return the tensor data as an `Eigen::Tensor` with the type and sizes of this ` Tensor `.
|
||||
* [`TTypes<T>::Matrix tensorflow::Tensor::matrix()`](#TTypes_T_Matrix_tensorflow_Tensor_matrix)
|
||||
* [`TTypes< T, NDIMS >::Tensor tensorflow::Tensor::tensor()`](#TTypes_T_NDIMS_Tensor_tensorflow_Tensor_tensor)
|
||||
* [`TTypes<T>::Flat tensorflow::Tensor::flat()`](#TTypes_T_Flat_tensorflow_Tensor_flat)
|
||||
* Return the Tensor data as an Eigen::Tensor of the data type and a specified shape.
|
||||
* Return the tensor data as an `Eigen::Tensor` of the data type and a specified shape.
|
||||
* [`TTypes<T>::UnalignedFlat tensorflow::Tensor::unaligned_flat()`](#TTypes_T_UnalignedFlat_tensorflow_Tensor_unaligned_flat)
|
||||
* [`TTypes<T>::Matrix tensorflow::Tensor::flat_inner_dims()`](#TTypes_T_Matrix_tensorflow_Tensor_flat_inner_dims)
|
||||
* [`TTypes<T>::Matrix tensorflow::Tensor::flat_outer_dims()`](#TTypes_T_Matrix_tensorflow_Tensor_flat_outer_dims)
|
||||
* [`TTypes< T, NDIMS >::Tensor tensorflow::Tensor::shaped(gtl::ArraySlice< int64 > new_sizes)`](#TTypes_T_NDIMS_Tensor_tensorflow_Tensor_shaped)
|
||||
* [`TTypes< T, NDIMS >::UnalignedTensor tensorflow::Tensor::unaligned_shaped(gtl::ArraySlice< int64 > new_sizes)`](#TTypes_T_NDIMS_UnalignedTensor_tensorflow_Tensor_unaligned_shaped)
|
||||
* [`TTypes< T >::Scalar tensorflow::Tensor::scalar()`](#TTypes_T_Scalar_tensorflow_Tensor_scalar)
|
||||
* Return the Tensor data as a Tensor Map of fixed size 1: TensorMap<TensorFixedSize<T, 1>>.
|
||||
* Return the Tensor data as a `TensorMap` of fixed size 1: `TensorMap<TensorFixedSize<T, 1>>`.
|
||||
* [`TTypes<T>::ConstVec tensorflow::Tensor::vec() const`](#TTypes_T_ConstVec_tensorflow_Tensor_vec)
|
||||
* Const versions of all the methods above.
|
||||
* [`TTypes<T>::ConstMatrix tensorflow::Tensor::matrix() const`](#TTypes_T_ConstMatrix_tensorflow_Tensor_matrix)
|
||||
@ -69,12 +69,12 @@ Represents an n-dimensional array of values.
|
||||
* [`TTypes< T, NDIMS >::UnalignedConstTensor tensorflow::Tensor::unaligned_shaped(gtl::ArraySlice< int64 > new_sizes) const`](#TTypes_T_NDIMS_UnalignedConstTensor_tensorflow_Tensor_unaligned_shaped)
|
||||
* [`TTypes< T >::ConstScalar tensorflow::Tensor::scalar() const`](#TTypes_T_ConstScalar_tensorflow_Tensor_scalar)
|
||||
* [`string tensorflow::Tensor::SummarizeValue(int64 max_entries) const`](#string_tensorflow_Tensor_SummarizeValue)
|
||||
* Render the first max_entries values in *this into a string.
|
||||
* Render the first `max_entries` values in `*this` into a string.
|
||||
* [`string tensorflow::Tensor::DebugString() const`](#string_tensorflow_Tensor_DebugString)
|
||||
* A human-readable summary of the Tensor suitable for debugging.
|
||||
* A human-readable summary of the tensor suitable for debugging.
|
||||
* [`void tensorflow::Tensor::FillDescription(TensorDescription *description) const`](#void_tensorflow_Tensor_FillDescription)
|
||||
* [`StringPiece tensorflow::Tensor::tensor_data() const`](#StringPiece_tensorflow_Tensor_tensor_data)
|
||||
* Returns a StringPiece mapping the current tensor's buffer.
|
||||
* Returns a `StringPiece` mapping the current tensor's buffer.
|
||||
|
||||
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
|
||||
|
||||
@ -86,15 +86,15 @@ Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
|
||||
|
||||
#### `tensorflow::Tensor::Tensor(DataType type, const TensorShape &shape)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
|
||||
|
||||
Creates a Tensor of the given datatype and shape.
|
||||
Creates a Tensor of the given `type` and `shape`.
|
||||
|
||||
The underlying buffer is allocated using a CPUAllocator.
|
||||
The underlying buffer is allocated using a `CPUAllocator`.
|
||||
|
||||
#### `tensorflow::Tensor::Tensor(Allocator *a, DataType type, const TensorShape &shape)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
|
||||
|
||||
Creates a tensor with the input datatype and shape, using the allocator 'a' to allocate the underlying buffer.
|
||||
Creates a tensor with the input `type` and `shape`, using the allocator `a` to allocate the underlying buffer.
|
||||
|
||||
'a' must outlive the lifetime of this Tensor .
|
||||
`a` must outlive the lifetime of this Tensor .
|
||||
|
||||
#### `tensorflow::Tensor::Tensor(DataType type)` <a class="md-anchor" id="tensorflow_Tensor_Tensor"></a>
|
||||
|
||||
@ -130,7 +130,7 @@ Returns the shape of the tensor.
|
||||
|
||||
Convenience accessor for the tensor shape.
|
||||
|
||||
For all shape accessors, see comments for relevant methods of TensorShape in tensor_shape.h .
|
||||
For all shape accessors, see comments for relevant methods of ` TensorShape ` in ` tensor_shape.h `.
|
||||
|
||||
#### `int64 tensorflow::Tensor::dim_size(int d) const` <a class="md-anchor" id="int64_tensorflow_Tensor_dim_size"></a>
|
||||
|
||||
@ -172,7 +172,7 @@ Assign operator. This tensor shares other's underlying storage.
|
||||
|
||||
Copy the other tensor into this tensor and reshape it.
|
||||
|
||||
This tensor shares other's underlying storage. Returns true iff other.shape() has the same number of elements of the given "shape".
|
||||
This tensor shares other's underlying storage. Returns `true` iff `other.shape()` has the same number of elements of the given `shape`.
|
||||
|
||||
#### `Tensor tensorflow::Tensor::Slice(int64 dim0_start, int64 dim0_limit) const` <a class="md-anchor" id="Tensor_tensorflow_Tensor_Slice"></a>
|
||||
|
||||
@ -180,15 +180,15 @@ Slice this tensor along the 1st dimension.
|
||||
|
||||
I.e., the returned tensor satisifies returned[i, ...] == this[dim0_start + i, ...]. The returned tensor shares the underlying tensor buffer with this tensor.
|
||||
|
||||
NOTE: The returned tensor may not satisfies the same alignment requirement as this tensor depending on the shape. The caller must check the returned tensor's alignment before calling certain methods that have alignment requirement (e.g., flat() , tensor()).
|
||||
NOTE: The returned tensor may not satisfies the same alignment requirement as this tensor depending on the shape. The caller must check the returned tensor's alignment before calling certain methods that have alignment requirement (e.g., ` flat() `, `tensor()`).
|
||||
|
||||
REQUIRES: dims() >= 1 REQUIRES: 0 <= dim0_start <= dim0_limit <= dim_size(0)
|
||||
REQUIRES: ` dims() ` >= 1 REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
|
||||
|
||||
#### `bool tensorflow::Tensor::FromProto(const TensorProto &other) TF_MUST_USE_RESULT` <a class="md-anchor" id="bool_tensorflow_Tensor_FromProto"></a>
|
||||
|
||||
Parse "other' and construct the tensor.
|
||||
Parse `other` and construct the tensor.
|
||||
|
||||
Returns true iff the parsing succeeds. If the parsing fails, the state of "*this" is unchanged.
|
||||
Returns `true` iff the parsing succeeds. If the parsing fails, the state of `*this` is unchanged.
|
||||
|
||||
#### `bool tensorflow::Tensor::FromProto(Allocator *a, const TensorProto &other) TF_MUST_USE_RESULT` <a class="md-anchor" id="bool_tensorflow_Tensor_FromProto"></a>
|
||||
|
||||
@ -198,9 +198,9 @@ Returns true iff the parsing succeeds. If the parsing fails, the state of "*this
|
||||
|
||||
#### `void tensorflow::Tensor::AsProtoField(TensorProto *proto) const` <a class="md-anchor" id="void_tensorflow_Tensor_AsProtoField"></a>
|
||||
|
||||
Fills in "proto" with "*this" tensor's content.
|
||||
Fills in `proto` with `*this` tensor's content.
|
||||
|
||||
AsProtoField() fills in the repeated field for proto.dtype(), while AsProtoTensorContent() encodes the content in proto.tensor_content() in a compact form.
|
||||
` AsProtoField() ` fills in the repeated field for `proto.dtype()`, while `AsProtoTensorContent()` encodes the content in `proto.tensor_content()` in a compact form.
|
||||
|
||||
#### `void tensorflow::Tensor::AsProtoTensorContent(TensorProto *proto) const` <a class="md-anchor" id="void_tensorflow_Tensor_AsProtoTensorContent"></a>
|
||||
|
||||
@ -210,9 +210,9 @@ AsProtoField() fills in the repeated field for proto.dtype(), while AsProtoTenso
|
||||
|
||||
#### `TTypes<T>::Vec tensorflow::Tensor::vec()` <a class="md-anchor" id="TTypes_T_Vec_tensorflow_Tensor_vec"></a>
|
||||
|
||||
Return the Tensor data as an Eigen::Tensor with the type and sizes of this Tensor .
|
||||
Return the tensor data as an `Eigen::Tensor` with the type and sizes of this ` Tensor `.
|
||||
|
||||
Use these methods when you know the data type and the number of dimensions of the Tensor and you want an Eigen::Tensor automatically sized to the Tensor sizes. The implementation check fails if either type or sizes mismatch.
|
||||
Use these methods when you know the data type and the number of dimensions of the Tensor and you want an `Eigen::Tensor` automatically sized to the ` Tensor ` sizes. The implementation check fails if either type or sizes mismatch.
|
||||
|
||||
Example:
|
||||
|
||||
@ -240,9 +240,9 @@ auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
|
||||
|
||||
#### `TTypes<T>::Flat tensorflow::Tensor::flat()` <a class="md-anchor" id="TTypes_T_Flat_tensorflow_Tensor_flat"></a>
|
||||
|
||||
Return the Tensor data as an Eigen::Tensor of the data type and a specified shape.
|
||||
Return the tensor data as an `Eigen::Tensor` of the data type and a specified shape.
|
||||
|
||||
These methods allow you to access the data with the dimensions and sizes of your choice. You do not need to know the number of dimensions of the Tensor to call them. However, they CHECK that the type matches and the dimensions requested creates an Eigen::Tensor with the same number of elements as the Tensor .
|
||||
These methods allow you to access the data with the dimensions and sizes of your choice. You do not need to know the number of dimensions of the Tensor to call them. However, they `CHECK` that the type matches and the dimensions requested creates an `Eigen::Tensor` with the same number of elements as the tensor.
|
||||
|
||||
Example:
|
||||
|
||||
@ -295,9 +295,9 @@ Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all Tensor di
|
||||
|
||||
#### `TTypes< T >::Scalar tensorflow::Tensor::scalar()` <a class="md-anchor" id="TTypes_T_Scalar_tensorflow_Tensor_scalar"></a>
|
||||
|
||||
Return the Tensor data as a Tensor Map of fixed size 1: TensorMap<TensorFixedSize<T, 1>>.
|
||||
Return the Tensor data as a `TensorMap` of fixed size 1: `TensorMap<TensorFixedSize<T, 1>>`.
|
||||
|
||||
Using scalar() allows the compiler to perform optimizations as the size of the tensor is known at compile time.
|
||||
Using ` scalar() ` allows the compiler to perform optimizations as the size of the tensor is known at compile time.
|
||||
|
||||
#### `TTypes<T>::ConstVec tensorflow::Tensor::vec() const` <a class="md-anchor" id="TTypes_T_ConstVec_tensorflow_Tensor_vec"></a>
|
||||
|
||||
@ -361,13 +361,13 @@ Const versions of all the methods above.
|
||||
|
||||
#### `string tensorflow::Tensor::SummarizeValue(int64 max_entries) const` <a class="md-anchor" id="string_tensorflow_Tensor_SummarizeValue"></a>
|
||||
|
||||
Render the first max_entries values in *this into a string.
|
||||
Render the first `max_entries` values in `*this` into a string.
|
||||
|
||||
|
||||
|
||||
#### `string tensorflow::Tensor::DebugString() const` <a class="md-anchor" id="string_tensorflow_Tensor_DebugString"></a>
|
||||
|
||||
A human-readable summary of the Tensor suitable for debugging.
|
||||
A human-readable summary of the tensor suitable for debugging.
|
||||
|
||||
|
||||
|
||||
@ -375,14 +375,14 @@ A human-readable summary of the Tensor suitable for debugging.
|
||||
|
||||
|
||||
|
||||
Fill in the TensorDescription proto with metadata about the Tensor that is useful for monitoring and debugging.
|
||||
Fill in the `TensorDescription` proto with metadata about the tensor that is useful for monitoring and debugging.
|
||||
|
||||
#### `StringPiece tensorflow::Tensor::tensor_data() const` <a class="md-anchor" id="StringPiece_tensorflow_Tensor_tensor_data"></a>
|
||||
|
||||
Returns a StringPiece mapping the current tensor's buffer.
|
||||
Returns a `StringPiece` mapping the current tensor's buffer.
|
||||
|
||||
The returned StringPiece may point to memory location on devices that the CPU cannot address directly.
|
||||
The returned `StringPiece` may point to memory location on devices that the CPU cannot address directly.
|
||||
|
||||
NOTE: The underlying Tensor buffer is refcounted, so the lifetime of the contents mapped by the StringPiece matches the lifetime of the buffer; callers should arrange to make sure the buffer does not get destroyed while the StringPiece is still used.
|
||||
NOTE: The underlying tensor buffer is refcounted, so the lifetime of the contents mapped by the `StringPiece` matches the lifetime of the buffer; callers should arrange to make sure the buffer does not get destroyed while the `StringPiece` is still used.
|
||||
|
||||
REQUIRES: DataTypeCanUseMemcpy( dtype() ).
|
||||
REQUIRES: `DataTypeCanUseMemcpy( dtype() )`.
|
||||
|
@ -7,38 +7,37 @@ Manages the dimensions of a Tensor and their sizes.
|
||||
##Member Summary <a class="md-anchor" id="AUTOGENERATED-member-summary"></a>
|
||||
|
||||
* [`tensorflow::TensorShape::TensorShape(gtl::ArraySlice< int64 > dim_sizes)`](#tensorflow_TensorShape_TensorShape)
|
||||
* Construct a TensorShape from the provided sizes.. REQUIRES: dim_sizes[i] >= 0.
|
||||
* Construct a ` TensorShape ` from the provided sizes. REQUIRES: `dim_sizes[i] >= 0`
|
||||
* [`tensorflow::TensorShape::TensorShape(std::initializer_list< int64 > dim_sizes)`](#tensorflow_TensorShape_TensorShape)
|
||||
* [`tensorflow::TensorShape::TensorShape(const TensorShapeProto &proto)`](#tensorflow_TensorShape_TensorShape)
|
||||
* REQUIRES: IsValid(proto)
|
||||
* REQUIRES: `IsValid(proto)`
|
||||
* [`tensorflow::TensorShape::TensorShape()`](#tensorflow_TensorShape_TensorShape)
|
||||
* [`void tensorflow::TensorShape::Clear()`](#void_tensorflow_TensorShape_Clear)
|
||||
* Clear a tensor shape.
|
||||
* [`void tensorflow::TensorShape::AddDim(int64 size)`](#void_tensorflow_TensorShape_AddDim)
|
||||
* Add a dimension to the end ("inner-most"). REQUIRES: size >= 0.
|
||||
* Add a dimension to the end ("inner-most"). REQUIRES: `size >= 0`
|
||||
* [`void tensorflow::TensorShape::AppendShape(const TensorShape &shape)`](#void_tensorflow_TensorShape_AppendShape)
|
||||
* Appends all the dimensions from shape.
|
||||
* Appends all the dimensions from `shape`.
|
||||
* [`void tensorflow::TensorShape::InsertDim(int d, int64 size)`](#void_tensorflow_TensorShape_InsertDim)
|
||||
* Insert a dimension somewhere in the TensorShape . REQUIRES: "0 <= d <= dims()" REQUIRES: size >= 0.
|
||||
* Insert a dimension somewhere in the ` TensorShape `. REQUIRES: `0 <= d <= dims() ` REQUIRES: `size >= 0`
|
||||
* [`void tensorflow::TensorShape::set_dim(int d, int64 size)`](#void_tensorflow_TensorShape_set_dim)
|
||||
* Modifies the size of the dimension 'd' to be 'size' REQUIRES: "0 <= d < dims()" REQUIRES: size >= 0.
|
||||
* Modifies the size of the dimension `d` to be `size` REQUIRES: `0 <= d < dims() ` REQUIRES: `size >= 0`
|
||||
* [`void tensorflow::TensorShape::RemoveDim(int d)`](#void_tensorflow_TensorShape_RemoveDim)
|
||||
* Removes dimension 'd' from the TensorShape . REQUIRES: "0 <= d < dims()".
|
||||
* Removes dimension `d` from the ` TensorShape `. REQUIRES: `0 <= d < dims() `
|
||||
* [`int tensorflow::TensorShape::dims() const`](#int_tensorflow_TensorShape_dims)
|
||||
* Return the number of dimensions in the tensor.
|
||||
* [`int64 tensorflow::TensorShape::dim_size(int d) const`](#int64_tensorflow_TensorShape_dim_size)
|
||||
* Returns the number of elements in dimension "d". REQUIRES: "0 <= d < dims()".
|
||||
* Returns the number of elements in dimension `d`. REQUIRES: `0 <= d < dims() `
|
||||
* [`gtl::ArraySlice<int64> tensorflow::TensorShape::dim_sizes() const`](#gtl_ArraySlice_int64_tensorflow_TensorShape_dim_sizes)
|
||||
* Returns sizes of all dimensions.
|
||||
* [`int64 tensorflow::TensorShape::num_elements() const`](#int64_tensorflow_TensorShape_num_elements)
|
||||
* Returns the number of elements in the tensor.
|
||||
* [`bool tensorflow::TensorShape::IsSameSize(const TensorShape &b) const`](#bool_tensorflow_TensorShape_IsSameSize)
|
||||
* Returns true if *this and b have the same sizes. Ignores dimension names.
|
||||
* [`bool tensorflow::TensorShape::operator==(const TensorShape &b) const`](#bool_tensorflow_TensorShape_operator_)
|
||||
* [`void tensorflow::TensorShape::AsProto(TensorShapeProto *proto) const`](#void_tensorflow_TensorShape_AsProto)
|
||||
* Fill *proto from *this.
|
||||
* Fill `*proto` from `*this`.
|
||||
* [`Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizes() const`](#Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizes)
|
||||
* Fill *dsizes from *this.
|
||||
* Fill `*dsizes` from `*this`.
|
||||
* [`Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizesWithPadding() const`](#Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizesWithPadding)
|
||||
* [`TensorShapeIter tensorflow::TensorShape::begin() const`](#TensorShapeIter_tensorflow_TensorShape_begin)
|
||||
* For iterating through the dimensions.
|
||||
@ -47,13 +46,13 @@ Manages the dimensions of a Tensor and their sizes.
|
||||
* For error messages.
|
||||
* [`string tensorflow::TensorShape::ShortDebugString() const`](#string_tensorflow_TensorShape_ShortDebugString)
|
||||
* [`static bool tensorflow::TensorShape::IsValid(const TensorShapeProto &proto)`](#static_bool_tensorflow_TensorShape_IsValid)
|
||||
* Returns true iff "proto" is a valid tensor shape.
|
||||
* Returns `true` iff `proto` is a valid tensor shape.
|
||||
|
||||
##Member Details <a class="md-anchor" id="AUTOGENERATED-member-details"></a>
|
||||
|
||||
#### `tensorflow::TensorShape::TensorShape(gtl::ArraySlice< int64 > dim_sizes)` <a class="md-anchor" id="tensorflow_TensorShape_TensorShape"></a>
|
||||
|
||||
Construct a TensorShape from the provided sizes.. REQUIRES: dim_sizes[i] >= 0.
|
||||
Construct a ` TensorShape ` from the provided sizes. REQUIRES: `dim_sizes[i] >= 0`
|
||||
|
||||
|
||||
|
||||
@ -65,7 +64,7 @@ Construct a TensorShape from the provided sizes.. REQUIRES: dim_sizes[i] >= 0.
|
||||
|
||||
#### `tensorflow::TensorShape::TensorShape(const TensorShapeProto &proto)` <a class="md-anchor" id="tensorflow_TensorShape_TensorShape"></a>
|
||||
|
||||
REQUIRES: IsValid(proto)
|
||||
REQUIRES: `IsValid(proto)`
|
||||
|
||||
|
||||
|
||||
@ -73,7 +72,7 @@ REQUIRES: IsValid(proto)
|
||||
|
||||
|
||||
|
||||
Create a tensor shape with no dimensions and one element, which you can then call AddDim() on.
|
||||
Create a tensor shape with no dimensions and one element, which you can then call ` AddDim() ` on.
|
||||
|
||||
#### `void tensorflow::TensorShape::Clear()` <a class="md-anchor" id="void_tensorflow_TensorShape_Clear"></a>
|
||||
|
||||
@ -83,31 +82,31 @@ Clear a tensor shape.
|
||||
|
||||
#### `void tensorflow::TensorShape::AddDim(int64 size)` <a class="md-anchor" id="void_tensorflow_TensorShape_AddDim"></a>
|
||||
|
||||
Add a dimension to the end ("inner-most"). REQUIRES: size >= 0.
|
||||
Add a dimension to the end ("inner-most"). REQUIRES: `size >= 0`
|
||||
|
||||
|
||||
|
||||
#### `void tensorflow::TensorShape::AppendShape(const TensorShape &shape)` <a class="md-anchor" id="void_tensorflow_TensorShape_AppendShape"></a>
|
||||
|
||||
Appends all the dimensions from shape.
|
||||
Appends all the dimensions from `shape`.
|
||||
|
||||
|
||||
|
||||
#### `void tensorflow::TensorShape::InsertDim(int d, int64 size)` <a class="md-anchor" id="void_tensorflow_TensorShape_InsertDim"></a>
|
||||
|
||||
Insert a dimension somewhere in the TensorShape . REQUIRES: "0 <= d <= dims()" REQUIRES: size >= 0.
|
||||
Insert a dimension somewhere in the ` TensorShape `. REQUIRES: `0 <= d <= dims() ` REQUIRES: `size >= 0`
|
||||
|
||||
|
||||
|
||||
#### `void tensorflow::TensorShape::set_dim(int d, int64 size)` <a class="md-anchor" id="void_tensorflow_TensorShape_set_dim"></a>
|
||||
|
||||
Modifies the size of the dimension 'd' to be 'size' REQUIRES: "0 <= d < dims()" REQUIRES: size >= 0.
|
||||
Modifies the size of the dimension `d` to be `size` REQUIRES: `0 <= d < dims() ` REQUIRES: `size >= 0`
|
||||
|
||||
|
||||
|
||||
#### `void tensorflow::TensorShape::RemoveDim(int d)` <a class="md-anchor" id="void_tensorflow_TensorShape_RemoveDim"></a>
|
||||
|
||||
Removes dimension 'd' from the TensorShape . REQUIRES: "0 <= d < dims()".
|
||||
Removes dimension `d` from the ` TensorShape `. REQUIRES: `0 <= d < dims() `
|
||||
|
||||
|
||||
|
||||
@ -119,7 +118,7 @@ Return the number of dimensions in the tensor.
|
||||
|
||||
#### `int64 tensorflow::TensorShape::dim_size(int d) const` <a class="md-anchor" id="int64_tensorflow_TensorShape_dim_size"></a>
|
||||
|
||||
Returns the number of elements in dimension "d". REQUIRES: "0 <= d < dims()".
|
||||
Returns the number of elements in dimension `d`. REQUIRES: `0 <= d < dims() `
|
||||
|
||||
|
||||
|
||||
@ -133,13 +132,13 @@ Returns sizes of all dimensions.
|
||||
|
||||
Returns the number of elements in the tensor.
|
||||
|
||||
We use int64 and not size_t to be compatible with Eigen::Tensor which uses ptr_fi
|
||||
We use `int64` and not `size_t` to be compatible with `Eigen::Tensor` which uses `ptrdiff_t`.
|
||||
|
||||
#### `bool tensorflow::TensorShape::IsSameSize(const TensorShape &b) const` <a class="md-anchor" id="bool_tensorflow_TensorShape_IsSameSize"></a>
|
||||
|
||||
Returns true if *this and b have the same sizes. Ignores dimension names.
|
||||
|
||||
|
||||
Returns true if `*this` and `b` have the same sizes. Ignores dimension names.
|
||||
|
||||
#### `bool tensorflow::TensorShape::operator==(const TensorShape &b) const` <a class="md-anchor" id="bool_tensorflow_TensorShape_operator_"></a>
|
||||
|
||||
@ -149,13 +148,13 @@ Returns true if *this and b have the same sizes. Ignores dimension names.
|
||||
|
||||
#### `void tensorflow::TensorShape::AsProto(TensorShapeProto *proto) const` <a class="md-anchor" id="void_tensorflow_TensorShape_AsProto"></a>
|
||||
|
||||
Fill *proto from *this.
|
||||
Fill `*proto` from `*this`.
|
||||
|
||||
|
||||
|
||||
#### `Eigen::DSizes< Eigen::DenseIndex, NDIMS > tensorflow::TensorShape::AsEigenDSizes() const` <a class="md-anchor" id="Eigen_DSizes_Eigen_DenseIndex_NDIMS_tensorflow_TensorShape_AsEigenDSizes"></a>
|
||||
|
||||
Fill *dsizes from *this.
|
||||
Fill `*dsizes` from `*this`.
|
||||
|
||||
|
||||
|
||||
@ -163,7 +162,7 @@ Fill *dsizes from *this.
|
||||
|
||||
|
||||
|
||||
Same as AsEigenDSizes() but allows for NDIMS > dims() in which case we pad the rest of the sizes with 1.
|
||||
Same as ` AsEigenDSizes() ` but allows for `NDIMS > dims() ` in which case we pad the rest of the sizes with 1.
|
||||
|
||||
#### `TensorShapeIter tensorflow::TensorShape::begin() const` <a class="md-anchor" id="TensorShapeIter_tensorflow_TensorShape_begin"></a>
|
||||
|
||||
@ -191,6 +190,6 @@ For error messages.
|
||||
|
||||
#### `static bool tensorflow::TensorShape::IsValid(const TensorShapeProto &proto)` <a class="md-anchor" id="static_bool_tensorflow_TensorShape_IsValid"></a>
|
||||
|
||||
Returns true iff "proto" is a valid tensor shape.
|
||||
Returns `true` iff `proto` is a valid tensor shape.
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Class `tensorflow::TensorShapeUtils` <a class="md-anchor" id="AUTOGENERATED-class--tensorflow--tensorshapeutils-"></a>
|
||||
|
||||
Static helper routines for TensorShape . Includes a few common predicates on a tensor shape.
|
||||
Static helper routines for ` TensorShape `. Includes a few common predicates on a tensor shape.
|
||||
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ Static helper routines for TensorShape . Includes a few common predicates on a t
|
||||
* [`static bool tensorflow::TensorShapeUtils::IsMatrix(const TensorShape &shape)`](#static_bool_tensorflow_TensorShapeUtils_IsMatrix)
|
||||
* [`static bool tensorflow::TensorShapeUtils::IsMatrixOrHigher(const TensorShape &shape)`](#static_bool_tensorflow_TensorShapeUtils_IsMatrixOrHigher)
|
||||
* [`static TensorShape tensorflow::TensorShapeUtils::MakeShape(const T *dims, int n)`](#static_TensorShape_tensorflow_TensorShapeUtils_MakeShape)
|
||||
* Returns a TensorShape whose dimensions are dims[0], dims[1], ..., dims[n-1].
|
||||
* Returns a ` TensorShape ` whose dimensions are `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
|
||||
* [`static string tensorflow::TensorShapeUtils::ShapeListString(const gtl::ArraySlice< TensorShape > &shapes)`](#static_string_tensorflow_TensorShapeUtils_ShapeListString)
|
||||
* [`static bool tensorflow::TensorShapeUtils::StartsWith(const TensorShape &shape0, const TensorShape &shape1)`](#static_bool_tensorflow_TensorShapeUtils_StartsWith)
|
||||
|
||||
@ -64,7 +64,7 @@ Static helper routines for TensorShape . Includes a few common predicates on a t
|
||||
|
||||
#### `static TensorShape tensorflow::TensorShapeUtils::MakeShape(const T *dims, int n)` <a class="md-anchor" id="static_TensorShape_tensorflow_TensorShapeUtils_MakeShape"></a>
|
||||
|
||||
Returns a TensorShape whose dimensions are dims[0], dims[1], ..., dims[n-1].
|
||||
Returns a ` TensorShape ` whose dimensions are `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
|
||||
|
||||
|
||||
|
||||
|
@ -6,8 +6,9 @@ and the easiest to use, but the C++ API may offer some performance advantages
|
||||
in graph execution, and supports deployment to small devices such as Android.
|
||||
|
||||
Over time, we hope that the TensorFlow community will develop front ends for
|
||||
languages like Go, Java, Javascript, Lua R, and perhaps others. With SWIG, it's
|
||||
relatively easy to contribute a TensorFlow interface to your favorite language.
|
||||
languages like Go, Java, JavaScript, Lua R, and perhaps others. With
|
||||
[SWIG](http://swig.org), it's relatively easy to develop a TensorFlow interface
|
||||
for your favorite language.
|
||||
|
||||
Note: Many practical aspects of usage are covered in the Mechanics tab, and
|
||||
some additional documentation not specific to any particular language API is
|
||||
|
@ -102,7 +102,8 @@ Inside the virtualenv, install TensorFlow:
|
||||
You can then run your TensorFlow program like:
|
||||
|
||||
```bash
|
||||
(tensorflow)$ python tensorflow/models/image/mnist/convolutional.py
|
||||
(tensorflow)$ cd tensorflow/models/image/mnist
|
||||
(tensorflow)$ python convolutional.py
|
||||
|
||||
# When you are done using TensorFlow:
|
||||
(tensorflow)$ deactivate # Deactivate the virtualenv
|
||||
@ -304,10 +305,11 @@ $ pip install /tmp/tensorflow_pkg/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
## Train your first TensorFlow neural net model <a class="md-anchor" id="AUTOGENERATED-train-your-first-tensorflow-neural-net-model"></a>
|
||||
|
||||
From the root of your source tree, run:
|
||||
Starting from the root of your source tree, run:
|
||||
|
||||
```python
|
||||
$ python tensorflow/models/image/mnist/convolutional.py
|
||||
$ cd tensorflow/models/image/mnist
|
||||
$ python convolutional.py
|
||||
Succesfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
|
||||
Succesfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
|
||||
Succesfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
|
||||
@ -360,7 +362,7 @@ Solution: make sure you are using Python 2.7.
|
||||
If you encounter:
|
||||
|
||||
```python
|
||||
import six.moves.copyreg as copyreg
|
||||
import six.moves.copyreg as copyreg
|
||||
|
||||
ImportError: No module named copyreg
|
||||
```
|
||||
@ -386,3 +388,26 @@ There are several ways to fix this:
|
||||
|
||||
|
||||
|
||||
If you encounter:
|
||||
|
||||
```
|
||||
>>> import tensorflow as tf
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "/usr/local/lib/python2.7/site-packages/tensorflow/__init__.py", line 4, in <module>
|
||||
from tensorflow.python import *
|
||||
File "/usr/local/lib/python2.7/site-packages/tensorflow/python/__init__.py", line 13, in <module>
|
||||
from tensorflow.core.framework.graph_pb2 import *
|
||||
...
|
||||
File "/usr/local/lib/python2.7/site-packages/tensorflow/core/framework/tensor_shape_pb2.py", line 22, in <module>
|
||||
serialized_pb=_b('\n,tensorflow/core/framework/tensor_shape.proto\x12\ntensorflow\"d\n\x10TensorShapeProto\x12-\n\x03\x64im\x18\x02 \x03(\x0b\x32 .tensorflow.TensorShapeProto.Dim\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\tb\x06proto3')
|
||||
TypeError: __init__() got an unexpected keyword argument 'syntax'
|
||||
```
|
||||
|
||||
This is due to a conflict between protobuf versions (we require protobuf 3.0.0).
|
||||
The best current solution is to make sure older versions of protobuf are not
|
||||
installed, such as:
|
||||
|
||||
```bash
|
||||
brew reinstall --devel protobuf
|
||||
```
|
||||
|
0
tensorflow/g3doc/how_tos/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/how_tos/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/how_tos/adding_an_op/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/how_tos/adding_an_op/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/how_tos/reading_data/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/how_tos/reading_data/__init__.py
Executable file → Normal file
@ -322,10 +322,10 @@ with tf.variable_scope("foo", initializer=tf.constant_initializer(0.4)):
|
||||
w = tf.get_variable("w", [1], initializer=tf.constant_initializer(0.3)):
|
||||
assert w.eval() == 0.3 # Specific initializer overrides the default.
|
||||
with tf.variable_scope("bar"):
|
||||
v = get_variable("v", [1])
|
||||
v = tf.get_variable("v", [1])
|
||||
assert v.eval() == 0.4 # Inherited default initializer.
|
||||
with tf.variable_scope("baz", initializer=tf.constant_initializer(0.2)):
|
||||
v = get_variable("v", [1])
|
||||
v = tf.get_variable("v", [1])
|
||||
assert v.eval() == 0.2 # Changed default initializer.
|
||||
```
|
||||
|
||||
|
0
tensorflow/g3doc/tutorials/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/tutorials/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/tutorials/mnist/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/tutorials/mnist/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/tutorials/word2vec/__init__.py
Executable file → Normal file
0
tensorflow/g3doc/tutorials/word2vec/__init__.py
Executable file → Normal file
0
tensorflow/models/__init__.py
Normal file
0
tensorflow/models/__init__.py
Normal file
@ -7,7 +7,7 @@ ICLR 2013.
|
||||
Detailed instructions on how to get started and use them are available in the
|
||||
tutorials. Brief instructions are below.
|
||||
|
||||
* [Word2Vec Tutorial](http://tensorflow.org/tutorials/word2vec/)
|
||||
* [Word2Vec Tutorial](http://tensorflow.org/tutorials/word2vec/index.md)
|
||||
|
||||
To download the example text and evaluation data:
|
||||
|
||||
|
0
tensorflow/models/embedding/__init__.py
Executable file → Normal file
0
tensorflow/models/embedding/__init__.py
Executable file → Normal file
0
tensorflow/models/image/__init__.py
Normal file
0
tensorflow/models/image/__init__.py
Normal file
0
tensorflow/models/image/alexnet/__init__.py
Executable file → Normal file
0
tensorflow/models/image/alexnet/__init__.py
Executable file → Normal file
0
tensorflow/models/image/cifar10/__init__.py
Executable file → Normal file
0
tensorflow/models/image/cifar10/__init__.py
Executable file → Normal file
0
tensorflow/models/image/mnist/__init__.py
Executable file → Normal file
0
tensorflow/models/image/mnist/__init__.py
Executable file → Normal file
@ -2,8 +2,8 @@ This directory contains functions for creating recurrent neural networks
|
||||
and sequence-to-sequence models. Detailed instructions on how to get started
|
||||
and use them are available in the tutorials.
|
||||
|
||||
* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
|
||||
* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
|
||||
* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/index.md)
|
||||
* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/index.md)
|
||||
|
||||
Here is a short overview of what is in this directory.
|
||||
|
||||
|
0
tensorflow/models/rnn/ptb/__init__.py
Executable file → Normal file
0
tensorflow/models/rnn/ptb/__init__.py
Executable file → Normal file
@ -120,7 +120,7 @@ class PTBModel(object):
|
||||
tf.get_variable("softmax_w", [size, vocab_size]),
|
||||
tf.get_variable("softmax_b", [vocab_size]))
|
||||
loss = seq2seq.sequence_loss_by_example([logits],
|
||||
[tf.reshape(self._targets, -1)],
|
||||
[tf.reshape(self._targets, [-1])],
|
||||
[tf.ones([batch_size * num_steps])],
|
||||
vocab_size)
|
||||
self._cost = cost = tf.reduce_sum(loss) / batch_size
|
||||
|
0
tensorflow/models/rnn/translate/__init__.py
Executable file → Normal file
0
tensorflow/models/rnn/translate/__init__.py
Executable file → Normal file
0
tensorflow/opensource_only/__init__.py
Executable file → Normal file
0
tensorflow/opensource_only/__init__.py
Executable file → Normal file
0
tensorflow/opensource_only/pip_package/__init__.py
Executable file → Normal file
0
tensorflow/opensource_only/pip_package/__init__.py
Executable file → Normal file
0
tensorflow/python/client/__init__.py
Executable file → Normal file
0
tensorflow/python/client/__init__.py
Executable file → Normal file
0
tensorflow/python/framework/__init__.py
Executable file → Normal file
0
tensorflow/python/framework/__init__.py
Executable file → Normal file
@ -156,47 +156,66 @@ _TENSOR_CONTENT_TYPES = frozenset([
|
||||
])
|
||||
|
||||
|
||||
class _Message(object):
|
||||
|
||||
def __init__(self, message):
|
||||
self._message = message
|
||||
|
||||
def __repr__(self):
|
||||
return self._message
|
||||
|
||||
|
||||
def _FirstNotNone(l):
|
||||
for x in l:
|
||||
if x is not None:
|
||||
return x
|
||||
if isinstance(x, ops.Tensor):
|
||||
return _Message("list containing Tensors")
|
||||
else:
|
||||
return x
|
||||
return None
|
||||
|
||||
|
||||
def _NotNone(v):
|
||||
if v is None:
|
||||
return _Message("None")
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
def _FilterInt(v):
|
||||
if isinstance(v, (list, tuple)):
|
||||
return _FirstNotNone([_FilterInt(x) for x in v])
|
||||
return None if isinstance(v, numbers.Integral) else repr(v)
|
||||
return None if isinstance(v, numbers.Integral) else _NotNone(v)
|
||||
|
||||
|
||||
def _FilterFloat(v):
|
||||
if isinstance(v, (list, tuple)):
|
||||
return _FirstNotNone([_FilterFloat(x) for x in v])
|
||||
return None if isinstance(v, numbers.Real) else repr(v)
|
||||
return None if isinstance(v, numbers.Real) else _NotNone(v)
|
||||
|
||||
|
||||
def _FilterComplex(v):
|
||||
if isinstance(v, (list, tuple)):
|
||||
return _FirstNotNone([_FilterComplex(x) for x in v])
|
||||
return None if isinstance(v, numbers.Complex) else repr(v)
|
||||
return None if isinstance(v, numbers.Complex) else _NotNone(v)
|
||||
|
||||
|
||||
def _FilterStr(v):
|
||||
if isinstance(v, (list, tuple)):
|
||||
return _FirstNotNone([_FilterStr(x) for x in v])
|
||||
return None if isinstance(v, basestring) else repr(v)
|
||||
return None if isinstance(v, basestring) else _NotNone(v)
|
||||
|
||||
|
||||
def _FilterBool(v):
|
||||
if isinstance(v, (list, tuple)):
|
||||
return _FirstNotNone([_FilterBool(x) for x in v])
|
||||
return None if isinstance(v, bool) else repr(v)
|
||||
return None if isinstance(v, bool) else _NotNone(v)
|
||||
|
||||
|
||||
def _FilterNotTensor(v):
|
||||
if isinstance(v, (list, tuple)):
|
||||
return _FirstNotNone([_FilterNotTensor(x) for x in v])
|
||||
return repr(v) if isinstance(v, ops.Tensor) else None
|
||||
return str(v) if isinstance(v, ops.Tensor) else None
|
||||
|
||||
|
||||
_TF_TO_IS_OK = {
|
||||
@ -224,7 +243,7 @@ def _AssertCompatible(values, dtype):
|
||||
raise TypeError("List of Tensors when single Tensor expected")
|
||||
else:
|
||||
raise TypeError("Expected %s, got %s instead." %
|
||||
(dtype.name, mismatch))
|
||||
(dtype.name, repr(mismatch)))
|
||||
|
||||
|
||||
def make_tensor_proto(values, dtype=None, shape=None):
|
||||
|
0
tensorflow/python/kernel_tests/__init__.py
Executable file → Normal file
0
tensorflow/python/kernel_tests/__init__.py
Executable file → Normal file
@ -323,6 +323,34 @@ class FIFOQueueTest(tf.test.TestCase):
|
||||
enqueue_op.run()
|
||||
self.assertAllEqual(dequeued_t.eval(), elems)
|
||||
|
||||
def testEnqueueWrongShape(self):
|
||||
with self.test_session() as sess:
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((2, 2), (3, 3)))
|
||||
elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
|
||||
elems_bad = tf.placeholder(tf.int32)
|
||||
enqueue_op = q.enqueue((elems_ok, elems_bad))
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError, r"Expected \[3,3\], got \[3,4\]"):
|
||||
sess.run([enqueue_op],
|
||||
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
|
||||
sess.run([enqueue_op],
|
||||
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
|
||||
|
||||
def testEnqueueDequeueManyWrongShape(self):
|
||||
with self.test_session() as sess:
|
||||
q = tf.FIFOQueue(10, (tf.int32, tf.int32), ((2, 2), (3, 3)))
|
||||
elems_ok = np.array([1] * 8).reshape((2, 2, 2)).astype(np.int32)
|
||||
elems_bad = tf.placeholder(tf.int32)
|
||||
enqueue_op = q.enqueue_many((elems_ok, elems_bad))
|
||||
dequeued_t = q.dequeue_many(2)
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Shape mismatch in tuple component 1. "
|
||||
r"Expected \[2,3,3\], got \[2,3,4\]"):
|
||||
sess.run([enqueue_op],
|
||||
feed_dict={elems_bad: np.array([1] * 24).reshape((2, 3, 4))})
|
||||
dequeued_t.eval()
|
||||
|
||||
def testParallelEnqueueMany(self):
|
||||
with self.test_session() as sess:
|
||||
q = tf.FIFOQueue(1000, tf.float32, shapes=())
|
||||
|
0
tensorflow/python/lib/__init__.py
Executable file → Normal file
0
tensorflow/python/lib/__init__.py
Executable file → Normal file
0
tensorflow/python/lib/core/__init__.py
Executable file → Normal file
0
tensorflow/python/lib/core/__init__.py
Executable file → Normal file
0
tensorflow/python/lib/io/__init__.py
Executable file → Normal file
0
tensorflow/python/lib/io/__init__.py
Executable file → Normal file
0
tensorflow/python/ops/__init__.py
Executable file → Normal file
0
tensorflow/python/ops/__init__.py
Executable file → Normal file
@ -201,7 +201,9 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
|
||||
|
||||
with self.assertRaises(TypeError) as cm:
|
||||
self._lib.apply_op("Simple", a=[self.Tensor(types.int32)])
|
||||
self.assertStartsWith(cm.exception.message, "Expected int32, got")
|
||||
self.assertStartsWith(
|
||||
cm.exception.message,
|
||||
"Expected int32, got list containing Tensors instead.")
|
||||
|
||||
def testReservedInput(self):
|
||||
self._add_op("name: 'ReservedInput' "
|
||||
|
@ -185,10 +185,10 @@ def get_variable(name, shape=None, dtype=types.float32, initializer=None,
|
||||
|
||||
```python
|
||||
with tf.variable_scope("foo"):
|
||||
v = get_variable("v", [1]) # v.name == "foo/v:0"
|
||||
w = get_variable("w", [1]) # w.name == "foo/w:0"
|
||||
v = tf.get_variable("v", [1]) # v.name == "foo/v:0"
|
||||
w = tf.get_variable("w", [1]) # w.name == "foo/w:0"
|
||||
with tf.variable_scope("foo", reuse=True)
|
||||
v1 = get_variable("v") # The same as v above.
|
||||
v1 = tf.get_variable("v") # The same as v above.
|
||||
```
|
||||
|
||||
If initializer is `None` (the default), the default initializer passed in
|
||||
@ -240,7 +240,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
|
||||
|
||||
```python
|
||||
with tf.variable_scope("foo"):
|
||||
v = get_variable("v", [1])
|
||||
v = tf.get_variable("v", [1])
|
||||
with tf.variable_scope("foo", reuse=True):
|
||||
v1 = tf.get_variable("v", [1])
|
||||
assert v1 == v
|
||||
@ -250,7 +250,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
|
||||
|
||||
```python
|
||||
with tf.variable_scope("foo") as scope.
|
||||
v = get_variable("v", [1])
|
||||
v = tf.get_variable("v", [1])
|
||||
scope.reuse_variables()
|
||||
v1 = tf.get_variable("v", [1])
|
||||
assert v1 == v
|
||||
@ -261,7 +261,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
|
||||
|
||||
```python
|
||||
with tf.variable_scope("foo") as scope.
|
||||
v = get_variable("v", [1])
|
||||
v = tf.get_variable("v", [1])
|
||||
v1 = tf.get_variable("v", [1])
|
||||
# Raises ValueError("... v already exists ...").
|
||||
```
|
||||
@ -271,7 +271,7 @@ def variable_scope(name_or_scope, reuse=None, initializer=None):
|
||||
|
||||
```python
|
||||
with tf.variable_scope("foo", reuse=True):
|
||||
v = get_variable("v", [1])
|
||||
v = tf.get_variable("v", [1])
|
||||
# Raises ValueError("... v does not exists ...").
|
||||
```
|
||||
|
||||
|
0
tensorflow/python/platform/default/__init__.py
Executable file → Normal file
0
tensorflow/python/platform/default/__init__.py
Executable file → Normal file
0
tensorflow/python/summary/__init__.py
Executable file → Normal file
0
tensorflow/python/summary/__init__.py
Executable file → Normal file
0
tensorflow/python/summary/impl/__init__.py
Executable file → Normal file
0
tensorflow/python/summary/impl/__init__.py
Executable file → Normal file
0
tensorflow/python/training/__init__.py
Executable file → Normal file
0
tensorflow/python/training/__init__.py
Executable file → Normal file
@ -87,10 +87,12 @@ class SummaryWriter(object):
|
||||
This method wraps the provided summary in an `Event` procotol buffer
|
||||
and adds it to the event file.
|
||||
|
||||
You can pass the output of any summary op, as-is, to this function. You
|
||||
can also pass a `Summary` procotol buffer that you manufacture with your
|
||||
own data. This is commonly done to report evaluation results in event
|
||||
files.
|
||||
You can pass the result of evaluating any summary op, using
|
||||
[`Session.run()`](client.md#Session.run] or
|
||||
[`Tensor.eval()`](framework.md#Tensor.eval), to this
|
||||
function. Alternatively, you can pass a `tf.Summary` protocol
|
||||
buffer that you populate with your own data. The latter is
|
||||
commonly done to report evaluation results in event files.
|
||||
|
||||
Args:
|
||||
summary: A `Summary` protocol buffer, optionally serialized as a string.
|
||||
|
0
tensorflow/python/user_ops/__init__.py
Executable file → Normal file
0
tensorflow/python/user_ops/__init__.py
Executable file → Normal file
0
tensorflow/python/util/__init__.py
Executable file → Normal file
0
tensorflow/python/util/__init__.py
Executable file → Normal file
0
tensorflow/python/util/protobuf/__init__.py
Executable file → Normal file
0
tensorflow/python/util/protobuf/__init__.py
Executable file → Normal file
0
tensorflow/tensorboard/__init__.py
Executable file → Normal file
0
tensorflow/tensorboard/__init__.py
Executable file → Normal file
0
tensorflow/tensorboard/scripts/__init__.py
Normal file
0
tensorflow/tensorboard/scripts/__init__.py
Normal file
131
tensorflow/tensorboard/scripts/demo_from_server.py
Normal file
131
tensorflow/tensorboard/scripts/demo_from_server.py
Normal file
@ -0,0 +1,131 @@
|
||||
from __future__ import print_function
|
||||
import json
|
||||
import os
|
||||
import urllib2
|
||||
import requests
|
||||
import os.path
|
||||
import shutil
|
||||
|
||||
class TensorBoardStaticSerializer(object):
|
||||
"""Serialize all the routes from a TensorBoard server to static json."""
|
||||
|
||||
def __init__(self, host, port, path):
|
||||
self.server_address = '%s:%d/' % (host, port)
|
||||
EnsureDirectoryExists(path)
|
||||
self.path = path
|
||||
self.img_id = 0
|
||||
|
||||
def _QuoteOrNone(self, x):
|
||||
if x is None:
|
||||
return None
|
||||
else:
|
||||
return urllib2.quote(x)
|
||||
|
||||
def _RetrieveRoute(self, route, run=None, tag=None):
|
||||
"""Load route (possibly with run and tag), return the json."""
|
||||
r = self._SendRequest(route, run, tag)
|
||||
j = r.json()
|
||||
return j
|
||||
|
||||
def _SendRequest(self, route, run=None, tag=None):
|
||||
url = self.server_address + route
|
||||
run = self._QuoteOrNone(run)
|
||||
tag = self._QuoteOrNone(tag)
|
||||
if run is not None:
|
||||
url += '?run={}'.format(run)
|
||||
if tag is not None:
|
||||
url += '&tag={}'.format(tag)
|
||||
r = requests.get(url)
|
||||
if r.status_code != 200:
|
||||
raise IOError
|
||||
return r
|
||||
|
||||
def _SaveRouteJsonToDisk(self, data, route, run=None, tag=None):
|
||||
"""Save the route, run, tag result to a predictable spot on disk."""
|
||||
print('%s/%s/%s' % (route, run, tag))
|
||||
if run is not None:
|
||||
run = run.replace(' ', '_')
|
||||
if tag is not None:
|
||||
tag = tag.replace(' ', '_')
|
||||
tag = tag.replace('(', '_')
|
||||
tag = tag.replace(')', '_')
|
||||
components = [x for x in [self.path, route, run, tag] if x]
|
||||
path = os.path.join(*components) + '.json'
|
||||
EnsureDirectoryExists(os.path.dirname(path))
|
||||
with open(path, 'w') as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _RetrieveAndSave(self, route, run=None, tag=None):
|
||||
"""Retrieve data, and save it to disk."""
|
||||
data = self._RetrieveRoute(route, run, tag)
|
||||
self._SaveRouteJsonToDisk(data, route, run, tag)
|
||||
return data
|
||||
|
||||
def _SerializeImages(self, run, tag):
|
||||
"""Serialize all the images, and use ids not query parameters."""
|
||||
EnsureDirectoryExists(os.path.join(self.path, 'individualImage'))
|
||||
images = self._RetrieveRoute('images', run, tag)
|
||||
for im in images:
|
||||
q = im['query']
|
||||
im['query'] = self.img_id
|
||||
path = '%s/individualImage/%d.png' % (self.path, self.img_id)
|
||||
self.img_id += 1
|
||||
r = requests.get(self.server_address + 'individualImage?' + q)
|
||||
if r.status_code != 200:
|
||||
raise IOError
|
||||
with open(path, 'wb') as f:
|
||||
f.write(r.content)
|
||||
self._SaveRouteJsonToDisk(images, 'images', run, tag)
|
||||
|
||||
|
||||
def Run(self):
|
||||
"""Main method that loads and serializes everything."""
|
||||
runs = self._RetrieveAndSave('runs')
|
||||
for run, tag_type_to_tags in runs.iteritems():
|
||||
for tag_type, tags in tag_type_to_tags.iteritems():
|
||||
try:
|
||||
if tag_type == 'graph':
|
||||
if tags:
|
||||
r = self._SendRequest('graph', run, None)
|
||||
pbtxt = r.text
|
||||
fname = run.replace(' ', '_') + '.pbtxt'
|
||||
path = os.path.join(self.path, 'graph', fname)
|
||||
EnsureDirectoryExists(os.path.dirname(path))
|
||||
with open(path, 'w') as f:
|
||||
f.write(pbtxt)
|
||||
elif tag_type == 'images':
|
||||
for t in tags:
|
||||
self._SerializeImages(run, t)
|
||||
else:
|
||||
for t in tags:
|
||||
self._RetrieveAndSave(tag_type, run, t)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
print('Retrieval failed for %s/%s/%s' % (tag_type, run, tag))
|
||||
print('Got error: ', e)
|
||||
print('continuing...')
|
||||
continue
|
||||
except IOError as e:
|
||||
print('Retrieval failed for %s/%s/%s' % (tag_type, run, tag))
|
||||
print('Got error: ', e)
|
||||
print('continuing...')
|
||||
continue
|
||||
|
||||
|
||||
def EnsureDirectoryExists(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
def main(unused_argv=None):
|
||||
target = '/tmp/tensorboard_demo_data'
|
||||
port = 6006
|
||||
host = 'http://localhost'
|
||||
if os.path.exists(target):
|
||||
if os.path.isdir(target):
|
||||
shutil.rmtree(target)
|
||||
else:
|
||||
os.remove(target)
|
||||
x = TensorBoardStaticSerializer(host, port, target)
|
||||
x.Run()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
0
tensorflow/tools/__init__.py
Executable file → Normal file
0
tensorflow/tools/__init__.py
Executable file → Normal file
@ -1,4 +1,4 @@
|
||||
FROM b.gcr.io/tensorflow-testing/tensorflow
|
||||
FROM b.gcr.io/tensorflow/tensorflow
|
||||
|
||||
MAINTAINER Craig Citro <craigcitro@google.com>
|
||||
|
||||
@ -63,6 +63,6 @@ COPY tensorflow /tensorflow
|
||||
RUN bazel clean && \
|
||||
bazel build -c opt tensorflow/tools/pip_package:build_pip_package && \
|
||||
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
|
||||
pip install /tmp/pip/tensorflow-*.whl
|
||||
pip install --upgrade /tmp/pip/tensorflow-*.whl
|
||||
|
||||
WORKDIR /root
|
||||
|
@ -1,8 +0,0 @@
|
||||
FROM b.gcr.io/tensorflow-testing/tensorflow-gpu-flat
|
||||
|
||||
MAINTAINER Craig Citro <craigcitro@google.com>
|
||||
|
||||
WORKDIR /root
|
||||
EXPOSE 6006
|
||||
EXPOSE 8888
|
||||
RUN ["/bin/bash"]
|
@ -21,14 +21,12 @@ RUN ./configure
|
||||
|
||||
# Now we build
|
||||
RUN bazel clean && \
|
||||
bazel build -c opt --config=cuda tensorflow/tools/pip_package:build_pip_package 2>&1 | tee -a /tmp/bazel.log && \
|
||||
rm -rf /tmp/pip && \
|
||||
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
|
||||
pip install /tmp/pip/tensorflow-*.whl && \
|
||||
bazel clean
|
||||
bazel build -c opt --config=cuda tensorflow/tools/pip_package:build_pip_package
|
||||
|
||||
RUN rm -rf /usr/local/cuda && \
|
||||
rm -rf /usr/share/nvidia && \
|
||||
rm -rf /root/.cache/
|
||||
RUN rm -rf /tmp/pip && \
|
||||
bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
|
||||
pip install --upgrade /tmp/pip/tensorflow-*.whl
|
||||
|
||||
RUN rm -rf /usr/local/cuda
|
||||
|
||||
RUN ["/bin/bash"]
|
||||
|
@ -31,9 +31,6 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
|
||||
RUN pip --no-cache-dir install ipykernel && \
|
||||
python -m ipykernel.kernelspec
|
||||
|
||||
# Add any notebooks in this directory.
|
||||
COPY notebooks/*.ipynb /notebooks/
|
||||
|
||||
# Set up our notebook config.
|
||||
COPY jupyter_notebook_config.py /root/.jupyter/
|
||||
|
||||
@ -42,15 +39,9 @@ COPY jupyter_notebook_config.py /root/.jupyter/
|
||||
# We just add a little wrapper script.
|
||||
COPY run_jupyter.sh /
|
||||
|
||||
# Set the workdir so we see notebooks on the IPython landing page.
|
||||
WORKDIR /notebooks
|
||||
|
||||
# These are temporary while we sort out the GPU dependency.
|
||||
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# TensorBoard
|
||||
EXPOSE 6006
|
||||
# IPython
|
||||
EXPOSE 8888
|
||||
|
||||
CMD ["/run_jupyter.sh"]
|
||||
CMD ["/bin/bash"]
|
||||
|
83
tensorflow/tools/docker/README.md
Normal file
83
tensorflow/tools/docker/README.md
Normal file
@ -0,0 +1,83 @@
|
||||
# Using TensorFlow via Docker
|
||||
|
||||
This directory contains `Dockerfile`s to make it easy to get up and running with
|
||||
TensorFlow via [Docker](http://www.docker.com/).
|
||||
|
||||
## Installing Docker
|
||||
|
||||
General installation instructions are
|
||||
[on the Docker site](https://docs.docker.com/installation/), but we give some
|
||||
quick links here:
|
||||
|
||||
* [OSX](https://docs.docker.com/installation/mac/): [docker toolbox](https://www.docker.com/toolbox)
|
||||
* [ubuntu](https://docs.docker.com/installation/ubuntulinux/)
|
||||
|
||||
## Which containers exist?
|
||||
|
||||
We currently maintain three Docker container images:
|
||||
|
||||
* `b.gcr.io/tensorflow/tensorflow`, which is a minimal VM with TensorFlow and
|
||||
all dependencies.
|
||||
|
||||
* `b.gcr.io/tensorflow/tensorflow-full`, which contains a full source
|
||||
distribution and all required libraries to build and run TensorFlow from
|
||||
source.
|
||||
|
||||
* `b.gcr.io/tensorflow/tensorflow-full-gpu`, which is the same as the previous
|
||||
container, but built with GPU support.
|
||||
|
||||
## Running the container
|
||||
|
||||
Each of the containers is published to a Docker registry; for the non-GPU
|
||||
containers, running is as simple as
|
||||
|
||||
$ docker run -it b.gcr.io/tensorflow/tensorflow
|
||||
|
||||
For the container with GPU support, we require the user to make the appropriate
|
||||
NVidia libraries available on their system, as well as providing mappings so
|
||||
that the container can see the host's GPU. For most purposes, this can be
|
||||
accomplished via
|
||||
|
||||
$ export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda* | xargs -I{} echo '-v {}:{}')
|
||||
$ export DEVICES=$(\ls /dev/nvidia* | xargs -I{} echo '--device {}:{}')
|
||||
$ export CUDA_SRCS="-v /usr/local/cuda:/usr/local/cuda -v /usr/share/nvidia:/usr/share/nvidia"
|
||||
$ docker run -it $CUDA_SO $CUDA_SRCS $DEVICES b.gcr.io/tensorflow/tensorflow-full-gpu
|
||||
|
||||
Alternately, you can use the `docker_run_gpu.sh` script in this directory.
|
||||
|
||||
## Rebuilding the containers
|
||||
|
||||
### tensorflow/tensorflow
|
||||
|
||||
This one requires no extra setup -- just
|
||||
|
||||
$ docker build -t $USER/tensorflow -f Dockerfile.lite .
|
||||
|
||||
### tensorflow/tensorflow-full
|
||||
|
||||
This one requires a copy of the tensorflow source tree at `./tensorflow` (since
|
||||
we don't keep the `Dockerfile`s at the top of the tree). With that in place,
|
||||
just run
|
||||
|
||||
$ git clone https://github.com/tensorflow/tensorflow
|
||||
$ docker build -t $USER/tensorflow-full -f Dockerfile.cpu .
|
||||
|
||||
### tensorflow/tensorflow-gpu
|
||||
|
||||
This one requires a few steps, since we need the NVidia headers to be available
|
||||
*during* the build step, but we don't want them included in the final container
|
||||
image. We need to start by installing the NVidia libraries as described in the
|
||||
[CUDA setup instructions](/get_started/os_setup.md#install_cuda). With that
|
||||
complete, we can build via
|
||||
|
||||
$ cp -a /usr/local/cuda .
|
||||
$ docker build -t $USER/tensorflow-gpu-base -f Dockerfile.gpu_base .
|
||||
# Flatten the image
|
||||
$ export TC=$(docker create $USER/tensorflow-gpu-base)
|
||||
$ docker export $TC | docker import - $USER/tensorflow-gpu-flat
|
||||
$ docker rm $TC
|
||||
$ export TC=$(docker create $USER/tensorflow-gpu-flat /bin/bash)
|
||||
$ docker commit --change='CMD ["/bin/bash"]' --change='ENV CUDA_PATH /usr/local/cuda' --change='ENV LD_LIBRARY_PATH /usr/local/cuda/lib64' --change='WORKDIR /root' $TC $USER/tensorflow-full-gpu
|
||||
$ docker rm $TC
|
||||
|
||||
This final image is a full TensorFlow image with GPU support.
|
0
tensorflow/tools/docker/__init__.py
Executable file → Normal file
0
tensorflow/tools/docker/__init__.py
Executable file → Normal file
23
tensorflow/tools/docker/docker_run_gpu.sh
Executable file
23
tensorflow/tools/docker/docker_run_gpu.sh
Executable file
@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
export CUDA_HOME=${CUDA_HOME:-/usr/local/cuda}
|
||||
|
||||
if [ ! -d ${CUDA_HOME}/lib64 ]; then
|
||||
echo "Failed to locate CUDA libs at ${CUDA_HOME}/lib64."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export CUDA_SO=$(\ls /usr/lib/x86_64-linux-gnu/libcuda* | \
|
||||
xargs -I{} echo '-v {}:{}')
|
||||
export DEVICES=$(\ls /dev/nvidia* | \
|
||||
xargs -I{} echo '--device {}:{}')
|
||||
export CUDA_SRCS="-v ${CUDA_HOME}:${CUDA_HOME} -v /usr/share/nvidia:/usr/share/nvidia"
|
||||
|
||||
if [[ "${DEVICES}" = "" ]]; then
|
||||
echo "Failed to locate NVidia device(s). Did you want the non-GPU container?"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run -it $CUDA_SO $CUDA_SRCS $DEVICES b.gcr.io/tensorflow/tensorflow-full-gpu "$@"
|
@ -11,7 +11,7 @@ function main() {
|
||||
DEST=$1
|
||||
TMPDIR=$(mktemp -d -t tmp.XXXXXXXXXX)
|
||||
|
||||
echo `date` : "=== Using tmpdir: ${TMPDIR}"
|
||||
echo $(date) : "=== Using tmpdir: ${TMPDIR}"
|
||||
|
||||
if [ ! -d bazel-bin/tensorflow ]; then
|
||||
echo "Could not find bazel-bin. Did you run from the root of the build tree?"
|
||||
@ -26,13 +26,13 @@ function main() {
|
||||
cp tensorflow/tools/pip_package/setup.py ${TMPDIR}
|
||||
pushd ${TMPDIR}
|
||||
rm -f MANIFEST
|
||||
echo `date` : "=== Building wheel"
|
||||
python setup.py sdist bdist_wheel >/dev/null
|
||||
echo $(date) : "=== Building wheel"
|
||||
python setup.py bdist_wheel >/dev/null
|
||||
mkdir -p ${DEST}
|
||||
cp dist/* ${DEST}
|
||||
popd
|
||||
rm -rf ${TMPDIR}
|
||||
echo `date` : "=== Output wheel file is in: ${DEST}"
|
||||
echo $(date) : "=== Output wheel file is in: ${DEST}"
|
||||
}
|
||||
|
||||
main "$@"
|
||||
|
Loading…
Reference in New Issue
Block a user