First step of moving files out of tensorflow/core/public/. Here
we copy the original files to their new location and make the public/ versions #include the new location. Once all references are updated to point to the new location, we can delete the originals in public/. Change: 112622561
This commit is contained in:
parent
6cc49690ef
commit
db7478e899
@ -15,6 +15,14 @@
|
||||
example, the shape argument to `tf.reshape` can't be a scalar anymore). The
|
||||
open source release was already scalar strict, so outside Google `IsScalar`
|
||||
and `IsVector` are exact replacements.
|
||||
* The following files are being removed from `tensorflow/core/public/`:
|
||||
* `env.h` -> `../platform/env.h`
|
||||
* `status.h` -> `../lib/core/status.h`
|
||||
* `tensor.h` -> `../framework/tensor.h`
|
||||
* `tensor_shape.h` -> `../framework/tensor_shape.h`
|
||||
* `partial_tensor_shape.h` -> `../framework/partial_tensor_shape.h`
|
||||
* `tensorflow_server.h` deleted
|
||||
|
||||
|
||||
## Bug fixes
|
||||
|
||||
|
@ -91,6 +91,7 @@ cc_library(
|
||||
"lib/core/command_line_flags.h", # TODO(vrv): Delete.
|
||||
"lib/core/errors.h",
|
||||
"lib/core/notification.h",
|
||||
"lib/core/status.h",
|
||||
"lib/core/stringpiece.h",
|
||||
"lib/core/threadpool.h",
|
||||
"lib/gtl/array_slice.h",
|
||||
@ -115,6 +116,7 @@ cc_library(
|
||||
"lib/strings/str_util.h", # TODO(josh11b): make internal
|
||||
"lib/strings/strcat.h",
|
||||
"lib/strings/stringprintf.h",
|
||||
"platform/env.h",
|
||||
"platform/host_info.h", # TODO(josh11b): make internal
|
||||
"platform/init_main.h",
|
||||
"platform/logging.h",
|
||||
@ -126,8 +128,8 @@ cc_library(
|
||||
"platform/regexp.h",
|
||||
"platform/thread_annotations.h",
|
||||
"platform/types.h",
|
||||
"public/env.h", # TODO(josh11b): move to platform/
|
||||
"public/status.h", # TODO(josh11b): move to lib/core/
|
||||
"public/env.h", # Deprecated, use platform/env.h instead
|
||||
"public/status.h", # Deprecated: use lib/core/status.h instead
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":lib_internal"],
|
||||
@ -178,18 +180,21 @@ tf_cuda_library(
|
||||
"framework/op_def_util.h",
|
||||
"framework/op_gen_lib.h",
|
||||
"framework/op_kernel.h",
|
||||
"framework/partial_tensor_shape.h",
|
||||
"framework/queue_interface.h",
|
||||
"framework/reader_interface.h",
|
||||
"framework/reader_op_kernel.h",
|
||||
"framework/register_types.h",
|
||||
"framework/resource_mgr.h",
|
||||
"framework/tensor.h",
|
||||
"framework/tensor_shape.h",
|
||||
"framework/tensor_slice.h",
|
||||
"framework/tensor_types.h",
|
||||
"framework/tensor_util.h",
|
||||
"framework/type_traits.h",
|
||||
"framework/type_index.h",
|
||||
"framework/types.h",
|
||||
# TODO(josh11b): Move these from public/ to framework/
|
||||
# Deprecated. TODO(josh11b): Use the framework/ versions instead.
|
||||
"public/partial_tensor_shape.h",
|
||||
"public/tensor.h",
|
||||
"public/tensor_shape.h",
|
||||
@ -635,6 +640,8 @@ cc_library(
|
||||
"//tensorflow/core:android_srcs",
|
||||
],
|
||||
hdrs = [
|
||||
"framework/tensor.h",
|
||||
"platform/env.h",
|
||||
"platform/logging.h",
|
||||
"platform/port.h",
|
||||
"public/env.h",
|
||||
@ -670,8 +677,8 @@ cc_library(
|
||||
"lib/**/*.cc",
|
||||
"platform/*.h",
|
||||
"platform/*.cc",
|
||||
"public/env.h", # TODO(josh11b): move to platform/
|
||||
"public/status.h", # TODO(josh11b): move to lib/core/
|
||||
"public/env.h", # TODO(josh11b): delete this
|
||||
"public/status.h", # TODO(josh11b): delete this
|
||||
] + tf_additional_lib_srcs(),
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
@ -715,7 +722,8 @@ tf_cuda_library(
|
||||
"framework/**/*.cc",
|
||||
"util/**/*.h",
|
||||
"util/**/*.cc",
|
||||
# TODO(josh11b): Move these from public/ to framework/
|
||||
# TODO(josh11b): Delete these once everyone has switched to the
|
||||
# framework/ versions.
|
||||
"public/partial_tensor_shape.h",
|
||||
"public/tensor.h",
|
||||
"public/tensor_shape.h",
|
||||
|
170
tensorflow/core/framework/partial_tensor_shape.h
Normal file
170
tensorflow/core/framework/partial_tensor_shape.h
Normal file
@ -0,0 +1,170 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class PartialTensorShapeIter; // Declared below
|
||||
|
||||
/// Manages the partially known dimensions of a Tensor and their sizes.
|
||||
class PartialTensorShape {
|
||||
public:
|
||||
/// \brief Construct a `PartialTensorShape` from the provided sizes.
|
||||
/// REQUIRES: `dim_sizes[i] >= 0`
|
||||
explicit PartialTensorShape(gtl::ArraySlice<int64> dim_sizes);
|
||||
PartialTensorShape(std::initializer_list<int64> dim_sizes)
|
||||
: PartialTensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
|
||||
|
||||
/// REQUIRES: `IsValid(proto)`
|
||||
explicit PartialTensorShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Returns `true` iff `proto` is a valid partial tensor shape.
|
||||
static bool IsValid(const TensorShapeProto& proto);
|
||||
|
||||
/// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
|
||||
/// status otherwise.
|
||||
static Status IsValidShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Add a dimension to the end ("inner-most"), returns a new
|
||||
/// PartialTensorShape.
|
||||
/// REQUIRES: `size >= -1`, where -1 means unknown
|
||||
PartialTensorShape Concatenate(int64 size) const;
|
||||
|
||||
/// Appends all the dimensions from `shape`. Returns a new
|
||||
/// PartialTensorShape.
|
||||
PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
|
||||
|
||||
/// Merges all the dimensions from `shape`. Returns
|
||||
/// `InvalidArgument` error if either `shape` has a different rank
|
||||
/// or if any of the dimensions are incompatible.
|
||||
Status MergeWith(const PartialTensorShape& shape,
|
||||
PartialTensorShape* result) const;
|
||||
|
||||
/// Return the number of dimensions in the tensor.
|
||||
int dims() const { return dim_sizes_.size(); }
|
||||
|
||||
/// Return true iff the rank and all of the dimensions are well defined
|
||||
bool IsFullyDefined() const;
|
||||
|
||||
/// Return true iff the ranks match, and if the
|
||||
/// dimensions all either match or one is unknown.
|
||||
bool IsCompatibleWith(const PartialTensorShape& shape) const;
|
||||
|
||||
/// Return true iff the dimensions of `shape` are compatible with
|
||||
/// `*this`.
|
||||
bool IsCompatibleWith(const TensorShape& shape) const;
|
||||
|
||||
/// \brief Returns the number of elements in dimension `d`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
int64 dim_size(int d) const {
|
||||
DCHECK_GE(d, 0);
|
||||
DCHECK_LT(d, dims());
|
||||
return dim_sizes_[d];
|
||||
}
|
||||
|
||||
/// Returns sizes of all dimensions.
|
||||
gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }
|
||||
|
||||
/// Fill `*proto` from `*this`.
|
||||
void AsProto(TensorShapeProto* proto) const;
|
||||
|
||||
// Fill `*tensor_shape` from `*this`.
|
||||
// If `*this` is not fully defined, returns false and
|
||||
// `*tensor_shape` is left in an intermediate state. Otherwise
|
||||
// returns true.
|
||||
bool AsTensorShape(TensorShape* tensor_shape) const;
|
||||
|
||||
/// For error messages.
|
||||
string DebugString() const;
|
||||
static string DebugString(const TensorShapeProto& proto);
|
||||
|
||||
/// \brief Returns a `PartialTensorShape` whose dimensions are
|
||||
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
|
||||
/// considered "unknown".
|
||||
template <typename T>
|
||||
static Status MakePartialShape(const T* dims, int n, PartialTensorShape* out);
|
||||
|
||||
private:
|
||||
/// Create a tensor shape.
|
||||
PartialTensorShape();
|
||||
|
||||
gtl::InlinedVector<int64, 4> dim_sizes_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status PartialTensorShape::MakePartialShape(const T* dims, int n,
|
||||
PartialTensorShape* out) {
|
||||
*out = PartialTensorShape();
|
||||
out->dim_sizes_.reserve(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (dims[i] >= -1) {
|
||||
out->dim_sizes_.push_back(dims[i]);
|
||||
} else {
|
||||
return errors::InvalidArgument("Dimension ", dims[i], " must be >= -1");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
|
||||
/// common predicates on a partially known tensor shape.
|
||||
class PartialTensorShapeUtils {
|
||||
public:
|
||||
static string PartialShapeListString(
|
||||
const gtl::ArraySlice<PartialTensorShape>& shapes) {
|
||||
string result = "[";
|
||||
bool first = true;
|
||||
for (const PartialTensorShape& shape : shapes) {
|
||||
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
|
||||
first = false;
|
||||
}
|
||||
strings::StrAppend(&result, "]");
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool AreCompatible(
|
||||
const gtl::ArraySlice<PartialTensorShape>& shapes0,
|
||||
const gtl::ArraySlice<PartialTensorShape>& shapes1) {
|
||||
if (shapes0.size() == shapes1.size()) {
|
||||
for (int i = 0; i < shapes0.size(); ++i) {
|
||||
if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_PARTIAL_TENSOR_SHAPE_H_
|
513
tensorflow/core/framework/tensor.h
Normal file
513
tensorflow/core/framework/tensor.h
Normal file
@ -0,0 +1,513 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/public/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorBuffer; // Forward declaration.
|
||||
class TensorCApi;
|
||||
|
||||
/// Represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
/// Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
|
||||
Tensor();
|
||||
|
||||
/// \brief Creates a Tensor of the given `type` and `shape`.
|
||||
///
|
||||
/// The underlying buffer is allocated using a `CPUAllocator`.
|
||||
Tensor(DataType type, const TensorShape& shape);
|
||||
|
||||
/// \brief Creates a tensor with the input `type` and `shape`, using the
|
||||
/// allocator `a` to allocate the underlying buffer.
|
||||
///
|
||||
/// `a` must outlive the lifetime of this Tensor.
|
||||
Tensor(Allocator* a, DataType type, const TensorShape& shape);
|
||||
|
||||
/// \brief Creates a tensor with the input `type` and `shape`, using the
|
||||
/// allocator `a` and the specified "allocation_attr" to allocate the
|
||||
/// underlying buffer.
|
||||
///
|
||||
/// `a` must outlive the lifetime of this Tensor.
|
||||
Tensor(Allocator* a, DataType type, const TensorShape& shape,
|
||||
const AllocationAttributes& allocation_attr);
|
||||
|
||||
/// Creates an uninitialized Tensor of the given data type.
|
||||
explicit Tensor(DataType type);
|
||||
|
||||
Tensor(const Tensor& other); /// Copy constructor.
|
||||
|
||||
~Tensor();
|
||||
|
||||
/// Returns the data type.
|
||||
DataType dtype() const { return type_; }
|
||||
|
||||
/// Returns the shape of the tensor.
|
||||
const TensorShape& shape() const { return shape_; }
|
||||
|
||||
/// \brief Convenience accessor for the tensor shape.
|
||||
///
|
||||
/// For all shape accessors, see comments for relevant methods of
|
||||
/// `TensorShape` in `tensor_shape.h`.
|
||||
int dims() const { return shape().dims(); }
|
||||
|
||||
/// Convenience accessor for the tensor shape.
|
||||
int64 dim_size(int d) const { return shape().dim_size(d); }
|
||||
|
||||
/// Convenience accessor for the tensor shape.
|
||||
int64 NumElements() const { return shape().num_elements(); }
|
||||
|
||||
bool IsSameSize(const Tensor& b) const {
|
||||
return shape().IsSameSize(b.shape());
|
||||
}
|
||||
|
||||
// True iff the two tensors use the same underlying refcounted storage
|
||||
bool SharesBufferWith(const Tensor& b) const;
|
||||
|
||||
// The BufferHash of two tensors are equal when they share the same
|
||||
// underlying refcounted storage
|
||||
size_t BufferHash() const;
|
||||
|
||||
/// Has this Tensor been initialized?
|
||||
bool IsInitialized() const;
|
||||
|
||||
/// Returns the estimated memory usage of this tensor.
|
||||
size_t TotalBytes() const;
|
||||
|
||||
/// Assign operator. This tensor shares other's underlying storage.
|
||||
Tensor& operator=(const Tensor& other) {
|
||||
CopyFromInternal(other, other.shape());
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Copy the other tensor into this tensor and reshape it.
|
||||
///
|
||||
/// This tensor shares other's underlying storage. Returns `true`
|
||||
/// iff `other.shape()` has the same number of elements of the given
|
||||
/// `shape`.
|
||||
bool CopyFrom(const Tensor& other,
|
||||
const TensorShape& shape) TF_MUST_USE_RESULT {
|
||||
if (other.NumElements() != shape.num_elements()) return false;
|
||||
CopyFromInternal(other, shape);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// \brief Slice this tensor along the 1st dimension.
|
||||
|
||||
/// I.e., the returned tensor satisfies
|
||||
/// returned[i, ...] == this[dim0_start + i, ...].
|
||||
/// The returned tensor shares the underlying tensor buffer with this
|
||||
/// tensor.
|
||||
///
|
||||
/// NOTE: The returned tensor may not satisfies the same alignment
|
||||
/// requirement as this tensor depending on the shape. The caller
|
||||
/// must check the returned tensor's alignment before calling certain
|
||||
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
|
||||
///
|
||||
/// REQUIRES: `dims()` >= 1
|
||||
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
|
||||
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
|
||||
|
||||
/// \brief Parse `other` and construct the tensor.
|
||||
|
||||
/// Returns `true` iff the parsing succeeds. If the parsing fails,
|
||||
/// the state of `*this` is unchanged.
|
||||
bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
|
||||
bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;
|
||||
|
||||
/// \brief Fills in `proto` with `*this` tensor's content.
|
||||
///
|
||||
/// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
|
||||
/// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
|
||||
/// in a compact form.
|
||||
void AsProtoField(TensorProto* proto) const;
|
||||
void AsProtoTensorContent(TensorProto* proto) const;
|
||||
|
||||
/// \brief Return the tensor data as an `Eigen::Tensor` with the type and
|
||||
/// sizes of this `Tensor`.
|
||||
///
|
||||
/// Use these methods when you know the data type and the number of
|
||||
/// dimensions of the Tensor and you want an `Eigen::Tensor`
|
||||
/// automatically sized to the `Tensor` sizes. The implementation check
|
||||
/// fails if either type or sizes mismatch.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```c++
|
||||
///
|
||||
/// typedef float T;
|
||||
/// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
|
||||
/// auto mat = my_mat.matrix<T>(); // 2D Eigen::Tensor, 3 x 5.
|
||||
/// auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
|
||||
/// auto vec = my_mat.vec<T>(); // CHECK fails as my_mat is 2D.
|
||||
/// auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
|
||||
/// auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
|
||||
///
|
||||
/// ```
|
||||
template <typename T>
|
||||
typename TTypes<T>::Vec vec() {
|
||||
return tensor<T, 1>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix matrix() {
|
||||
return tensor<T, 2>();
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor tensor();
|
||||
|
||||
/// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a
|
||||
/// specified shape.
|
||||
///
|
||||
/// These methods allow you to access the data with the dimensions
|
||||
/// and sizes of your choice. You do not need to know the number of
|
||||
/// dimensions of the Tensor to call them. However, they `CHECK` that
|
||||
/// the type matches and the dimensions requested creates an
|
||||
/// `Eigen::Tensor` with the same number of elements as the tensor.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```c++
|
||||
///
|
||||
/// typedef float T;
|
||||
/// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
|
||||
/// // 1D Eigen::Tensor, size 60:
|
||||
/// auto flat = my_ten.flat<T>();
|
||||
/// // 2D Eigen::Tensor 12 x 5:
|
||||
/// auto inner = my_ten.flat_inner_dims<T>();
|
||||
/// // 2D Eigen::Tensor 4 x 15:
|
||||
/// auto outer = my_ten.shaped<T, 2>({4, 15});
|
||||
/// // CHECK fails, bad num elements:
|
||||
/// auto outer = my_ten.shaped<T, 2>({4, 8});
|
||||
/// // 3D Eigen::Tensor 6 x 5 x 2:
|
||||
/// auto weird = my_ten.shaped<T, 3>({6, 5, 2});
|
||||
/// // CHECK fails, type mismatch:
|
||||
/// auto bad = my_ten.flat<int32>();
|
||||
///
|
||||
/// ```
|
||||
template <typename T>
|
||||
typename TTypes<T>::Flat flat() {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedFlat unaligned_flat() {
|
||||
return unaligned_shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
|
||||
/// Tensor dimensions but the last one into the first dimension of the result.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix flat_inner_dims() {
|
||||
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
|
||||
if (last_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({NumElements() / last_size, last_size});
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
|
||||
/// Tensor dimensions but the first one into the last dimension of the result.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix flat_outer_dims() {
|
||||
int64 first_size = dims() > 0 ? dim_size(0) : 1;
|
||||
if (first_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({first_size, NumElements() / first_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes);
|
||||
|
||||
/// \brief Return the Tensor data as a `TensorMap` of fixed size 1:
|
||||
/// `TensorMap<TensorFixedSize<T, 1>>`.
|
||||
|
||||
/// Using `scalar()` allows the compiler to perform optimizations as
|
||||
/// the size of the tensor is known at compile time.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Scalar scalar();
|
||||
|
||||
/// Const versions of all the methods above.
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstVec vec() const {
|
||||
return tensor<T, 1>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix matrix() const {
|
||||
return tensor<T, 2>();
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstFlat flat() const {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
|
||||
return unaligned_shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix flat_inner_dims() const {
|
||||
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
|
||||
if (last_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({NumElements() / last_size, last_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix flat_outer_dims() const {
|
||||
int64 first_size = dims() > 0 ? dim_size(0) : 1;
|
||||
if (first_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({first_size, NumElements() / first_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const;
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const;
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstScalar scalar() const;
|
||||
|
||||
/// Render the first `max_entries` values in `*this` into a string.
|
||||
string SummarizeValue(int64 max_entries) const;
|
||||
|
||||
/// A human-readable summary of the tensor suitable for debugging.
|
||||
string DebugString() const;
|
||||
|
||||
/// Fill in the `TensorDescription` proto with metadata about the
|
||||
/// tensor that is useful for monitoring and debugging.
|
||||
void FillDescription(TensorDescription* description) const;
|
||||
|
||||
/// \brief Returns a `StringPiece` mapping the current tensor's buffer.
|
||||
///
|
||||
/// The returned `StringPiece` may point to memory location on devices
|
||||
/// that the CPU cannot address directly.
|
||||
///
|
||||
/// NOTE: The underlying tensor buffer is refcounted, so the lifetime
|
||||
/// of the contents mapped by the `StringPiece` matches the lifetime of
|
||||
/// the buffer; callers should arrange to make sure the buffer does
|
||||
/// not get destroyed while the `StringPiece` is still used.
|
||||
///
|
||||
/// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
|
||||
StringPiece tensor_data() const;
|
||||
|
||||
private:
|
||||
DataType type_;
|
||||
TensorShape shape_;
|
||||
TensorBuffer* buf_;
|
||||
|
||||
friend class DMAHelper;
|
||||
friend class TensorCApi;
|
||||
friend class TensorReference; // For access to buf_
|
||||
friend class VariableOp; // For access to set_shape
|
||||
friend class AutoReloadVariableOp; // For access to set_shape
|
||||
|
||||
// Creates a tensor with the input datatype, shape and buf.
|
||||
//
|
||||
// Acquires a ref on buf that belongs to this Tensor.
|
||||
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
|
||||
|
||||
bool CanUseDMA() const;
|
||||
|
||||
// Only needed by variable op to set the shape of an uninitialized
|
||||
// Tensor.
|
||||
// TODO: Remove this when we have a better story for detecting
|
||||
// uninitialized tensors.
|
||||
void set_shape(const TensorShape& shape) { shape_ = shape; }
|
||||
|
||||
void CopyFromInternal(const Tensor& other, const TensorShape& shape);
|
||||
|
||||
template <typename T>
|
||||
T* base() const;
|
||||
};
|
||||
|
||||
// Implementation details
|
||||
|
||||
// Interface to access the raw ref-counted data buffer.
|
||||
class TensorBuffer : public core::RefCounted {
|
||||
public:
|
||||
~TensorBuffer() override {}
|
||||
|
||||
// data() points to a memory region of size() bytes.
|
||||
virtual void* data() const = 0;
|
||||
virtual size_t size() const = 0;
|
||||
|
||||
// If this TensorBuffer is sub-buffer of another TensorBuffer,
|
||||
// returns that TensorBuffer. Otherwise, returns this.
|
||||
virtual TensorBuffer* root_buffer() = 0;
|
||||
|
||||
// Fill metadata about the allocation into the proto.
|
||||
virtual void FillAllocationDescription(
|
||||
AllocationDescription* proto) const = 0;
|
||||
|
||||
template <typename T>
|
||||
T* base() const {
|
||||
return reinterpret_cast<T*>(data());
|
||||
}
|
||||
};
|
||||
|
||||
inline void CheckEigenAlignment(const void* ptr) {
|
||||
#if EIGEN_ALIGN == 1
|
||||
CHECK_EQ(reinterpret_cast<intptr_t>(ptr) % EIGEN_ALIGN_BYTES, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* Tensor::base() const {
|
||||
return buf_ == nullptr ? nullptr : buf_->base<T>();
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CheckEigenAlignment(base<T>());
|
||||
return typename TTypes<T, NDIMS>::Tensor(base<T>(),
|
||||
shape().AsEigenDSizes<NDIMS>());
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
|
||||
shape().AsEigenDSizes<NDIMS>());
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) {
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const {
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::Scalar Tensor::scalar() {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
|
||||
return typename TTypes<T>::Scalar(base<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstScalar Tensor::scalar() const {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
|
||||
return typename TTypes<T>::ConstScalar(base<T>());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_H_
|
248
tensorflow/core/framework/tensor_shape.h
Normal file
248
tensorflow/core/framework/tensor_shape.h
Normal file
@ -0,0 +1,248 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
|
||||
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorShapeIter; // Declared below
|
||||
|
||||
/// Manages the dimensions of a Tensor and their sizes.
|
||||
class TensorShape {
|
||||
public:
|
||||
/// \brief Construct a `TensorShape` from the provided sizes.
|
||||
/// REQUIRES: `dim_sizes[i] >= 0`
|
||||
explicit TensorShape(gtl::ArraySlice<int64> dim_sizes);
|
||||
TensorShape(std::initializer_list<int64> dim_sizes)
|
||||
: TensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
|
||||
|
||||
/// REQUIRES: `IsValid(proto)`
|
||||
explicit TensorShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Create a tensor shape with no dimensions and one element, which you can
|
||||
/// then call `AddDim()` on.
|
||||
TensorShape();
|
||||
|
||||
/// Returns `true` iff `proto` is a valid tensor shape.
|
||||
static bool IsValid(const TensorShapeProto& proto);
|
||||
|
||||
/// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
|
||||
/// status otherwise.
|
||||
static Status IsValidShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Clear a tensor shape
|
||||
void Clear();
|
||||
|
||||
/// \brief Add a dimension to the end ("inner-most").
|
||||
/// REQUIRES: `size >= 0`
|
||||
void AddDim(int64 size);
|
||||
|
||||
/// Appends all the dimensions from `shape`.
|
||||
void AppendShape(const TensorShape& shape);
|
||||
|
||||
/// \brief Insert a dimension somewhere in the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d <= dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void InsertDim(int d, int64 size);
|
||||
|
||||
/// \brief Modifies the size of the dimension `d` to be `size`
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void set_dim(int d, int64 size);
|
||||
|
||||
/// \brief Removes dimension `d` from the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
void RemoveDim(int d);
|
||||
|
||||
/// Return the number of dimensions in the tensor.
|
||||
int dims() const { return dim_sizes_.size(); }
|
||||
|
||||
/// \brief Returns the number of elements in dimension `d`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
// TODO(touts): Rename to `dimension()` to match
|
||||
// `Eigen::Tensor::dimension()`?
|
||||
int64 dim_size(int d) const {
|
||||
DCHECK_GE(d, 0);
|
||||
DCHECK_LT(d, dims());
|
||||
return dim_sizes_[d];
|
||||
}
|
||||
|
||||
/// Returns sizes of all dimensions.
|
||||
gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }
|
||||
|
||||
/// \brief Returns the number of elements in the tensor.
|
||||
///
|
||||
/// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
|
||||
/// which uses `ptrdiff_t`.
|
||||
int64 num_elements() const { return num_elements_; }
|
||||
|
||||
/// Returns true if `*this` and `b` have the same sizes. Ignores
|
||||
/// dimension names.
|
||||
bool IsSameSize(const TensorShape& b) const;
|
||||
bool operator==(const TensorShape& b) const { return IsSameSize(b); }
|
||||
|
||||
/// Fill `*proto` from `*this`.
|
||||
void AsProto(TensorShapeProto* proto) const;
|
||||
|
||||
/// Fill `*dsizes` from `*this`.
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;
|
||||
|
||||
/// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
|
||||
/// which case we pad the rest of the sizes with 1.
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;
|
||||
|
||||
/// For iterating through the dimensions.
|
||||
TensorShapeIter begin() const;
|
||||
TensorShapeIter end() const;
|
||||
|
||||
/// For error messages.
|
||||
string DebugString() const;
|
||||
|
||||
/// Same as DebugString()
|
||||
string ShortDebugString() const { return DebugString(); }
|
||||
// TODO(irving): Remove, used to be different but isn't now.
|
||||
|
||||
/// Same as `TensorShape(proto).ShortDebugString()` but doesn't crash for
|
||||
/// invalid protos.
|
||||
static string ShortDebugString(const TensorShapeProto& proto);
|
||||
// TODO(irving): Rename to DebugString.
|
||||
|
||||
private:
|
||||
// Recalculates the dimensions of this tensor after they are modified.
|
||||
void recompute_dims();
|
||||
|
||||
// TODO(josh11b): Maybe use something from the Eigen Tensor library
|
||||
// for the sizes.
|
||||
gtl::InlinedVector<int64, 4> dim_sizes_;
|
||||
|
||||
// total number of elements (avoids recomputing it each time).
|
||||
int64 num_elements_;
|
||||
};
|
||||
|
||||
struct TensorShapeDim {
|
||||
explicit TensorShapeDim(int64 s) : size(s) {}
|
||||
int size;
|
||||
};
|
||||
|
||||
class TensorShapeIter {
|
||||
public:
|
||||
TensorShapeIter(const TensorShape* shape, int d) : shape_(shape), d_(d) {}
|
||||
bool operator==(const TensorShapeIter& rhs) {
|
||||
DCHECK(shape_ == rhs.shape_);
|
||||
return d_ == rhs.d_;
|
||||
}
|
||||
bool operator!=(const TensorShapeIter& rhs) {
|
||||
DCHECK(shape_ == rhs.shape_);
|
||||
return d_ != rhs.d_;
|
||||
}
|
||||
void operator++() { ++d_; }
|
||||
TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
|
||||
|
||||
private:
|
||||
const TensorShape* shape_;
|
||||
int d_;
|
||||
};
|
||||
|
||||
/// \brief Static helper routines for `TensorShape`. Includes a few common
|
||||
/// predicates on a tensor shape.
|
||||
class TensorShapeUtils {
|
||||
public:
|
||||
static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
|
||||
|
||||
static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
|
||||
|
||||
static bool IsVectorOrHigher(const TensorShape& shape) {
|
||||
return shape.dims() >= 1;
|
||||
}
|
||||
|
||||
static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
|
||||
|
||||
static bool IsMatrixOrHigher(const TensorShape& shape) {
|
||||
return shape.dims() >= 2;
|
||||
}
|
||||
|
||||
/// \brief Returns a `TensorShape` whose dimensions are
|
||||
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
|
||||
template <typename T>
|
||||
static Status MakeShape(const T* dims, int n, TensorShape* out) {
|
||||
*out = TensorShape();
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (dims[i] >= 0) {
|
||||
out->AddDim(dims[i]);
|
||||
} else {
|
||||
return errors::InvalidArgument("Dimension ", dims[i], " must be >= 0");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
|
||||
string result = "[";
|
||||
bool first = true;
|
||||
for (const TensorShape& shape : shapes) {
|
||||
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
|
||||
first = false;
|
||||
}
|
||||
strings::StrAppend(&result, "]");
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1);
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Template method implementation details below
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
|
||||
CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
|
||||
<< " for a tensor of " << dims() << " dimensions";
|
||||
return AsEigenDSizesWithPadding<NDIMS>();
|
||||
}
|
||||
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
|
||||
const {
|
||||
CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS
|
||||
<< " for a tensor of " << dims() << " dimensions";
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
|
||||
for (int d = 0; d < dims(); d++) {
|
||||
dsizes[d] = dim_size(d);
|
||||
}
|
||||
for (int d = dims(); d < NDIMS; d++) {
|
||||
dsizes[d] = 1;
|
||||
}
|
||||
return dsizes;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
|
112
tensorflow/core/lib/core/status.h
Normal file
112
tensorflow/core/lib/core/status.h
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_H_
|
||||
#define TENSORFLOW_CORE_LIB_CORE_STATUS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <iosfwd>
|
||||
#include <string>
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Status {
|
||||
public:
|
||||
/// Create a success status.
|
||||
Status() : state_(NULL) {}
|
||||
~Status() { delete state_; }
|
||||
|
||||
/// \brief Create a status with the specified error code and msg as a
|
||||
/// human-readable string containing more detailed information.
|
||||
Status(tensorflow::error::Code code, tensorflow::StringPiece msg);
|
||||
|
||||
/// Copy the specified status.
|
||||
Status(const Status& s);
|
||||
void operator=(const Status& s);
|
||||
|
||||
static Status OK() { return Status(); }
|
||||
|
||||
/// Returns true iff the status indicates success.
|
||||
bool ok() const { return (state_ == NULL); }
|
||||
|
||||
tensorflow::error::Code code() const {
|
||||
return ok() ? tensorflow::error::OK : state_->code;
|
||||
}
|
||||
|
||||
const string& error_message() const {
|
||||
return ok() ? empty_string() : state_->msg;
|
||||
}
|
||||
|
||||
bool operator==(const Status& x) const;
|
||||
bool operator!=(const Status& x) const;
|
||||
|
||||
/// \brief If `ok()`, stores `new_status` into `*this`. If `!ok()`,
|
||||
/// preserves the current status, but may augment with additional
|
||||
/// information about `new_status`.
|
||||
///
|
||||
/// Convenient way of keeping track of the first error encountered.
|
||||
/// Instead of:
|
||||
/// `if (overall_status.ok()) overall_status = new_status`
|
||||
/// Use:
|
||||
/// `overall_status.Update(new_status);`
|
||||
void Update(const Status& new_status);
|
||||
|
||||
/// \brief Return a string representation of this status suitable for
|
||||
/// printing. Returns the string `"OK"` for success.
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
static const string& empty_string();
|
||||
struct State {
|
||||
tensorflow::error::Code code;
|
||||
string msg;
|
||||
};
|
||||
// OK status has a `NULL` state_. Otherwise, `state_` points to
|
||||
// a `State` structure containing the error code and message(s)
|
||||
State* state_;
|
||||
|
||||
void SlowCopyFrom(const State* src);
|
||||
};
|
||||
|
||||
inline Status::Status(const Status& s)
|
||||
: state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {}
|
||||
|
||||
inline void Status::operator=(const Status& s) {
|
||||
// The following condition catches both aliasing (when this == &s),
|
||||
// and the common case where both s and *this are ok.
|
||||
if (state_ != s.state_) {
|
||||
SlowCopyFrom(s.state_);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool Status::operator==(const Status& x) const {
|
||||
return (this->state_ == x.state_) || (ToString() == x.ToString());
|
||||
}
|
||||
|
||||
inline bool Status::operator!=(const Status& x) const { return !(*this == x); }
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Status& x);
|
||||
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
#define TF_CHECK_OK(val) CHECK_EQ(::tensorflow::Status::OK(), (val))
|
||||
#define TF_QCHECK_OK(val) QCHECK_EQ(::tensorflow::Status::OK(), (val))
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_LIB_CORE_STATUS_H_
|
327
tensorflow/core/platform/env.h
Normal file
327
tensorflow/core/platform/env.h
Normal file
@ -0,0 +1,327 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_ENV_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_ENV_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RandomAccessFile;
|
||||
class Thread;
|
||||
class WritableFile;
|
||||
struct ThreadOptions;
|
||||
|
||||
/// \brief An interface used by the tensorflow implementation to
|
||||
/// access operating system functionality like the filesystem etc.
|
||||
///
|
||||
/// Callers may wish to provide a custom Env object to get fine grain
|
||||
/// control.
|
||||
///
|
||||
/// All Env implementations are safe for concurrent access from
|
||||
/// multiple threads without any external synchronization.
|
||||
class Env {
|
||||
public:
|
||||
Env() {}
|
||||
virtual ~Env();
|
||||
|
||||
/// \brief Returns a default environment suitable for the current operating
|
||||
/// system.
|
||||
///
|
||||
/// Sophisticated users may wish to provide their own Env
|
||||
/// implementation instead of relying on this default environment.
|
||||
///
|
||||
/// The result of Default() belongs to this library and must never be deleted.
|
||||
static Env* Default();
|
||||
|
||||
/// \brief Creates a brand new random access read-only file with the
|
||||
/// specified name.
|
||||
|
||||
/// On success, stores a pointer to the new file in
|
||||
/// *result and returns OK. On failure stores NULL in *result and
|
||||
/// returns non-OK. If the file does not exist, returns a non-OK
|
||||
/// status.
|
||||
///
|
||||
/// The returned file may be concurrently accessed by multiple threads.
|
||||
virtual Status NewRandomAccessFile(const string& fname,
|
||||
RandomAccessFile** result) = 0;
|
||||
|
||||
/// \brief Creates an object that writes to a new file with the specified
|
||||
/// name.
|
||||
///
|
||||
/// Deletes any existing file with the same name and creates a
|
||||
/// new file. On success, stores a pointer to the new file in
|
||||
/// *result and returns OK. On failure stores NULL in *result and
|
||||
/// returns non-OK.
|
||||
///
|
||||
/// The returned file will only be accessed by one thread at a time.
|
||||
virtual Status NewWritableFile(const string& fname,
|
||||
WritableFile** result) = 0;
|
||||
|
||||
/// \brief Creates an object that either appends to an existing file, or
|
||||
/// writes to a new file (if the file does not exist to begin with).
|
||||
///
|
||||
/// On success, stores a pointer to the new file in *result and
|
||||
/// returns OK. On failure stores NULL in *result and returns
|
||||
/// non-OK.
|
||||
///
|
||||
/// The returned file will only be accessed by one thread at a time.
|
||||
virtual Status NewAppendableFile(const string& fname,
|
||||
WritableFile** result) = 0;
|
||||
|
||||
/// Returns true iff the named file exists.
|
||||
virtual bool FileExists(const string& fname) = 0;
|
||||
|
||||
/// \brief Stores in *result the names of the children of the specified
|
||||
/// directory. The names are relative to "dir".
|
||||
///
|
||||
/// Original contents of *results are dropped.
|
||||
virtual Status GetChildren(const string& dir,
|
||||
std::vector<string>* result) = 0;
|
||||
|
||||
/// Deletes the named file.
|
||||
virtual Status DeleteFile(const string& fname) = 0;
|
||||
|
||||
/// Creates the specified directory.
|
||||
virtual Status CreateDir(const string& dirname) = 0;
|
||||
|
||||
/// Deletes the specified directory.
|
||||
virtual Status DeleteDir(const string& dirname) = 0;
|
||||
|
||||
/// Stores the size of `fname` in `*file_size`.
|
||||
virtual Status GetFileSize(const string& fname, uint64* file_size) = 0;
|
||||
|
||||
/// \brief Renames file src to target. If target already exists, it will be
|
||||
/// replaced.
|
||||
virtual Status RenameFile(const string& src, const string& target) = 0;
|
||||
|
||||
// TODO(jeff,sanjay): Add back thread/thread-pool support if needed.
|
||||
// TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
|
||||
// provide a routine to get the absolute time.
|
||||
|
||||
/// \brief Returns the number of micro-seconds since some fixed point in
|
||||
/// time. Only useful for computing deltas of time.
|
||||
virtual uint64 NowMicros() = 0;
|
||||
|
||||
/// Sleeps/delays the thread for the prescribed number of micro-seconds.
|
||||
virtual void SleepForMicroseconds(int micros) = 0;
|
||||
|
||||
/// \brief Returns a new thread that is running fn() and is identified
|
||||
/// (for debugging/performance-analysis) by "name".
|
||||
///
|
||||
/// Caller takes ownership of the result and must delete it eventually
|
||||
/// (the deletion will block until fn() stops running).
|
||||
virtual Thread* StartThread(const ThreadOptions& thread_options,
|
||||
const string& name,
|
||||
std::function<void()> fn) TF_MUST_USE_RESULT = 0;
|
||||
|
||||
// \brief Schedules the given closure on a thread-pool.
|
||||
//
|
||||
// NOTE(mrry): This closure must not block.
|
||||
virtual void SchedClosure(std::function<void()> closure) = 0;
|
||||
|
||||
// \brief Schedules the given closure on a thread-pool after the given number
|
||||
// of microseconds.
|
||||
//
|
||||
// NOTE(mrry): This closure must not block.
|
||||
virtual void SchedClosureAfter(int micros, std::function<void()> closure) = 0;
|
||||
|
||||
// \brief Load a dynamic library.
|
||||
//
|
||||
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||
// loading a library. The rules for determining the exact location of the
|
||||
// library are platform-specific and are not documented here.
|
||||
//
|
||||
// On success, returns a handle to the library in "*handle" and returns
|
||||
// OK from the function.
|
||||
// Otherwise returns nullptr in "*handle" and an error status from the
|
||||
// function.
|
||||
virtual Status LoadLibrary(const char* library_filename, void** handle) = 0;
|
||||
|
||||
// \brief Get a pointer to a symbol from a dynamic library.
|
||||
//
|
||||
// "handle" should be a pointer returned from a previous call to LoadLibrary.
|
||||
// On success, store a pointer to the located symbol in "*symbol" and return
|
||||
// OK from the function. Otherwise, returns nullptr in "*symbol" and an error
|
||||
// status from the function.
|
||||
virtual Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
|
||||
void** symbol) = 0;
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
Env(const Env&);
|
||||
void operator=(const Env&);
|
||||
};
|
||||
|
||||
/// A file abstraction for randomly reading the contents of a file.
|
||||
class RandomAccessFile {
|
||||
public:
|
||||
RandomAccessFile() {}
|
||||
virtual ~RandomAccessFile();
|
||||
|
||||
/// \brief Reads up to `n` bytes from the file starting at `offset`.
|
||||
///
|
||||
/// `scratch[0..n-1]` may be written by this routine. Sets `*result`
|
||||
/// to the data that was read (including if fewer than `n` bytes were
|
||||
/// successfully read). May set `*result` to point at data in
|
||||
/// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when
|
||||
/// `*result` is used.
|
||||
///
|
||||
/// On OK returned status: `n` bytes have been stored in `*result`.
|
||||
/// On non-OK returned status: `[0..n]` bytes have been stored in `*result`.
|
||||
///
|
||||
/// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result`
|
||||
/// because of EOF.
|
||||
///
|
||||
/// Safe for concurrent use by multiple threads.
|
||||
virtual Status Read(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const = 0;
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
RandomAccessFile(const RandomAccessFile&);
|
||||
void operator=(const RandomAccessFile&);
|
||||
};
|
||||
|
||||
/// \brief A file abstraction for sequential writing.
|
||||
///
|
||||
/// The implementation must provide buffering since callers may append
|
||||
/// small fragments at a time to the file.
|
||||
class WritableFile {
|
||||
public:
|
||||
WritableFile() {}
|
||||
virtual ~WritableFile();
|
||||
|
||||
virtual Status Append(const StringPiece& data) = 0;
|
||||
virtual Status Close() = 0;
|
||||
virtual Status Flush() = 0;
|
||||
virtual Status Sync() = 0;
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
WritableFile(const WritableFile&);
|
||||
void operator=(const WritableFile&);
|
||||
};
|
||||
|
||||
/// \brief An implementation of Env that forwards all calls to another Env.
|
||||
///
|
||||
/// May be useful to clients who wish to override just part of the
|
||||
/// functionality of another Env.
|
||||
class EnvWrapper : public Env {
|
||||
public:
|
||||
/// Initializes an EnvWrapper that delegates all calls to *t
|
||||
explicit EnvWrapper(Env* t) : target_(t) {}
|
||||
virtual ~EnvWrapper();
|
||||
|
||||
/// Returns the target to which this Env forwards all calls
|
||||
Env* target() const { return target_; }
|
||||
|
||||
// The following text is boilerplate that forwards all methods to target()
|
||||
Status NewRandomAccessFile(const string& f, RandomAccessFile** r) override {
|
||||
return target_->NewRandomAccessFile(f, r);
|
||||
}
|
||||
Status NewWritableFile(const string& f, WritableFile** r) override {
|
||||
return target_->NewWritableFile(f, r);
|
||||
}
|
||||
Status NewAppendableFile(const string& f, WritableFile** r) override {
|
||||
return target_->NewAppendableFile(f, r);
|
||||
}
|
||||
bool FileExists(const string& f) override { return target_->FileExists(f); }
|
||||
Status GetChildren(const string& dir, std::vector<string>* r) override {
|
||||
return target_->GetChildren(dir, r);
|
||||
}
|
||||
Status DeleteFile(const string& f) override { return target_->DeleteFile(f); }
|
||||
Status CreateDir(const string& d) override { return target_->CreateDir(d); }
|
||||
Status DeleteDir(const string& d) override { return target_->DeleteDir(d); }
|
||||
Status GetFileSize(const string& f, uint64* s) override {
|
||||
return target_->GetFileSize(f, s);
|
||||
}
|
||||
Status RenameFile(const string& s, const string& t) override {
|
||||
return target_->RenameFile(s, t);
|
||||
}
|
||||
uint64 NowMicros() override { return target_->NowMicros(); }
|
||||
void SleepForMicroseconds(int micros) override {
|
||||
target_->SleepForMicroseconds(micros);
|
||||
}
|
||||
Thread* StartThread(const ThreadOptions& thread_options, const string& name,
|
||||
std::function<void()> fn) override {
|
||||
return target_->StartThread(thread_options, name, fn);
|
||||
}
|
||||
void SchedClosure(std::function<void()> closure) override {
|
||||
target_->SchedClosure(closure);
|
||||
}
|
||||
void SchedClosureAfter(int micros, std::function<void()> closure) override {
|
||||
target_->SchedClosureAfter(micros, closure);
|
||||
}
|
||||
Status LoadLibrary(const char* library_filename, void** handle) override {
|
||||
return target_->LoadLibrary(library_filename, handle);
|
||||
}
|
||||
Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
|
||||
void** symbol) override {
|
||||
return target_->GetSymbolFromLibrary(handle, symbol_name, symbol);
|
||||
}
|
||||
|
||||
private:
|
||||
Env* target_;
|
||||
};
|
||||
|
||||
class Thread {
|
||||
public:
|
||||
Thread() {}
|
||||
|
||||
/// Blocks until the thread of control stops running.
|
||||
virtual ~Thread();
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
Thread(const Thread&);
|
||||
void operator=(const Thread&);
|
||||
};
|
||||
|
||||
/// \brief Options to configure a Thread.
|
||||
///
|
||||
/// Note that the options are all hints, and the
|
||||
/// underlying implementation may choose to ignore it.
|
||||
struct ThreadOptions {
|
||||
/// Thread stack size to use (in bytes).
|
||||
size_t stack_size = 0; // 0: use system default value
|
||||
/// Guard area size to use near thread stacks to use (in bytes)
|
||||
size_t guard_size = 0; // 0: use system default value
|
||||
};
|
||||
|
||||
/// A utility routine: reads contents of named file into `*data`
|
||||
Status ReadFileToString(Env* env, const string& fname, string* data);
|
||||
|
||||
/// A utility routine: write contents of `data` to file named `fname`
|
||||
/// (overwriting existing contents, if any).
|
||||
Status WriteStringToFile(Env* env, const string& fname,
|
||||
const StringPiece& data);
|
||||
|
||||
/// Reads contents of named file and parse as binary encoded proto data
|
||||
/// and store into `*proto`.
|
||||
Status ReadBinaryProto(Env* env, const string& fname,
|
||||
::tensorflow::protobuf::MessageLite* proto);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_ENV_H_
|
@ -16,312 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PUBLIC_ENV_H_
|
||||
#define TENSORFLOW_PUBLIC_ENV_H_
|
||||
|
||||
#include <stdint.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RandomAccessFile;
|
||||
class Thread;
|
||||
class WritableFile;
|
||||
struct ThreadOptions;
|
||||
|
||||
/// \brief An interface used by the tensorflow implementation to
|
||||
/// access operating system functionality like the filesystem etc.
|
||||
///
|
||||
/// Callers may wish to provide a custom Env object to get fine grain
|
||||
/// control.
|
||||
///
|
||||
/// All Env implementations are safe for concurrent access from
|
||||
/// multiple threads without any external synchronization.
|
||||
class Env {
|
||||
public:
|
||||
Env() {}
|
||||
virtual ~Env();
|
||||
|
||||
/// \brief Returns a default environment suitable for the current operating
|
||||
/// system.
|
||||
///
|
||||
/// Sophisticated users may wish to provide their own Env
|
||||
/// implementation instead of relying on this default environment.
|
||||
///
|
||||
/// The result of Default() belongs to this library and must never be deleted.
|
||||
static Env* Default();
|
||||
|
||||
/// \brief Creates a brand new random access read-only file with the
|
||||
/// specified name.
|
||||
|
||||
/// On success, stores a pointer to the new file in
|
||||
/// *result and returns OK. On failure stores NULL in *result and
|
||||
/// returns non-OK. If the file does not exist, returns a non-OK
|
||||
/// status.
|
||||
///
|
||||
/// The returned file may be concurrently accessed by multiple threads.
|
||||
virtual Status NewRandomAccessFile(const string& fname,
|
||||
RandomAccessFile** result) = 0;
|
||||
|
||||
/// \brief Creates an object that writes to a new file with the specified
|
||||
/// name.
|
||||
///
|
||||
/// Deletes any existing file with the same name and creates a
|
||||
/// new file. On success, stores a pointer to the new file in
|
||||
/// *result and returns OK. On failure stores NULL in *result and
|
||||
/// returns non-OK.
|
||||
///
|
||||
/// The returned file will only be accessed by one thread at a time.
|
||||
virtual Status NewWritableFile(const string& fname,
|
||||
WritableFile** result) = 0;
|
||||
|
||||
/// \brief Creates an object that either appends to an existing file, or
|
||||
/// writes to a new file (if the file does not exist to begin with).
|
||||
///
|
||||
/// On success, stores a pointer to the new file in *result and
|
||||
/// returns OK. On failure stores NULL in *result and returns
|
||||
/// non-OK.
|
||||
///
|
||||
/// The returned file will only be accessed by one thread at a time.
|
||||
virtual Status NewAppendableFile(const string& fname,
|
||||
WritableFile** result) = 0;
|
||||
|
||||
/// Returns true iff the named file exists.
|
||||
virtual bool FileExists(const string& fname) = 0;
|
||||
|
||||
/// \brief Stores in *result the names of the children of the specified
|
||||
/// directory. The names are relative to "dir".
|
||||
///
|
||||
/// Original contents of *results are dropped.
|
||||
virtual Status GetChildren(const string& dir,
|
||||
std::vector<string>* result) = 0;
|
||||
|
||||
/// Deletes the named file.
|
||||
virtual Status DeleteFile(const string& fname) = 0;
|
||||
|
||||
/// Creates the specified directory.
|
||||
virtual Status CreateDir(const string& dirname) = 0;
|
||||
|
||||
/// Deletes the specified directory.
|
||||
virtual Status DeleteDir(const string& dirname) = 0;
|
||||
|
||||
/// Stores the size of `fname` in `*file_size`.
|
||||
virtual Status GetFileSize(const string& fname, uint64* file_size) = 0;
|
||||
|
||||
/// \brief Renames file src to target. If target already exists, it will be
|
||||
/// replaced.
|
||||
virtual Status RenameFile(const string& src, const string& target) = 0;
|
||||
|
||||
// TODO(jeff,sanjay): Add back thread/thread-pool support if needed.
|
||||
// TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
|
||||
// provide a routine to get the absolute time.
|
||||
|
||||
/// \brief Returns the number of micro-seconds since some fixed point in
|
||||
/// time. Only useful for computing deltas of time.
|
||||
virtual uint64 NowMicros() = 0;
|
||||
|
||||
/// Sleeps/delays the thread for the prescribed number of micro-seconds.
|
||||
virtual void SleepForMicroseconds(int micros) = 0;
|
||||
|
||||
/// \brief Returns a new thread that is running fn() and is identified
|
||||
/// (for debugging/performance-analysis) by "name".
|
||||
///
|
||||
/// Caller takes ownership of the result and must delete it eventually
|
||||
/// (the deletion will block until fn() stops running).
|
||||
virtual Thread* StartThread(const ThreadOptions& thread_options,
|
||||
const string& name,
|
||||
std::function<void()> fn) TF_MUST_USE_RESULT = 0;
|
||||
|
||||
// \brief Schedules the given closure on a thread-pool.
|
||||
//
|
||||
// NOTE(mrry): This closure must not block.
|
||||
virtual void SchedClosure(std::function<void()> closure) = 0;
|
||||
|
||||
// \brief Schedules the given closure on a thread-pool after the given number
|
||||
// of microseconds.
|
||||
//
|
||||
// NOTE(mrry): This closure must not block.
|
||||
virtual void SchedClosureAfter(int micros, std::function<void()> closure) = 0;
|
||||
|
||||
// \brief Load a dynamic library.
|
||||
//
|
||||
// Pass "library_filename" to a platform-specific mechanism for dynamically
|
||||
// loading a library. The rules for determining the exact location of the
|
||||
// library are platform-specific and are not documented here.
|
||||
//
|
||||
// On success, returns a handle to the library in "*handle" and returns
|
||||
// OK from the function.
|
||||
// Otherwise returns nullptr in "*handle" and an error status from the
|
||||
// function.
|
||||
virtual Status LoadLibrary(const char* library_filename, void** handle) = 0;
|
||||
|
||||
// \brief Get a pointer to a symbol from a dynamic library.
|
||||
//
|
||||
// "handle" should be a pointer returned from a previous call to LoadLibrary.
|
||||
// On success, store a pointer to the located symbol in "*symbol" and return
|
||||
// OK from the function. Otherwise, returns nullptr in "*symbol" and an error
|
||||
// status from the function.
|
||||
virtual Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
|
||||
void** symbol) = 0;
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
Env(const Env&);
|
||||
void operator=(const Env&);
|
||||
};
|
||||
|
||||
/// A file abstraction for randomly reading the contents of a file.
|
||||
class RandomAccessFile {
|
||||
public:
|
||||
RandomAccessFile() {}
|
||||
virtual ~RandomAccessFile();
|
||||
|
||||
/// \brief Reads up to `n` bytes from the file starting at `offset`.
|
||||
///
|
||||
/// `scratch[0..n-1]` may be written by this routine. Sets `*result`
|
||||
/// to the data that was read (including if fewer than `n` bytes were
|
||||
/// successfully read). May set `*result` to point at data in
|
||||
/// `scratch[0..n-1]`, so `scratch[0..n-1]` must be live when
|
||||
/// `*result` is used.
|
||||
///
|
||||
/// On OK returned status: `n` bytes have been stored in `*result`.
|
||||
/// On non-OK returned status: `[0..n]` bytes have been stored in `*result`.
|
||||
///
|
||||
/// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result`
|
||||
/// because of EOF.
|
||||
///
|
||||
/// Safe for concurrent use by multiple threads.
|
||||
virtual Status Read(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const = 0;
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
RandomAccessFile(const RandomAccessFile&);
|
||||
void operator=(const RandomAccessFile&);
|
||||
};
|
||||
|
||||
/// \brief A file abstraction for sequential writing.
|
||||
///
|
||||
/// The implementation must provide buffering since callers may append
|
||||
/// small fragments at a time to the file.
|
||||
class WritableFile {
|
||||
public:
|
||||
WritableFile() {}
|
||||
virtual ~WritableFile();
|
||||
|
||||
virtual Status Append(const StringPiece& data) = 0;
|
||||
virtual Status Close() = 0;
|
||||
virtual Status Flush() = 0;
|
||||
virtual Status Sync() = 0;
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
WritableFile(const WritableFile&);
|
||||
void operator=(const WritableFile&);
|
||||
};
|
||||
|
||||
/// \brief An implementation of Env that forwards all calls to another Env.
|
||||
///
|
||||
/// May be useful to clients who wish to override just part of the
|
||||
/// functionality of another Env.
|
||||
class EnvWrapper : public Env {
|
||||
public:
|
||||
/// Initializes an EnvWrapper that delegates all calls to *t
|
||||
explicit EnvWrapper(Env* t) : target_(t) {}
|
||||
virtual ~EnvWrapper();
|
||||
|
||||
/// Returns the target to which this Env forwards all calls
|
||||
Env* target() const { return target_; }
|
||||
|
||||
// The following text is boilerplate that forwards all methods to target()
|
||||
Status NewRandomAccessFile(const string& f, RandomAccessFile** r) override {
|
||||
return target_->NewRandomAccessFile(f, r);
|
||||
}
|
||||
Status NewWritableFile(const string& f, WritableFile** r) override {
|
||||
return target_->NewWritableFile(f, r);
|
||||
}
|
||||
Status NewAppendableFile(const string& f, WritableFile** r) override {
|
||||
return target_->NewAppendableFile(f, r);
|
||||
}
|
||||
bool FileExists(const string& f) override { return target_->FileExists(f); }
|
||||
Status GetChildren(const string& dir, std::vector<string>* r) override {
|
||||
return target_->GetChildren(dir, r);
|
||||
}
|
||||
Status DeleteFile(const string& f) override { return target_->DeleteFile(f); }
|
||||
Status CreateDir(const string& d) override { return target_->CreateDir(d); }
|
||||
Status DeleteDir(const string& d) override { return target_->DeleteDir(d); }
|
||||
Status GetFileSize(const string& f, uint64* s) override {
|
||||
return target_->GetFileSize(f, s);
|
||||
}
|
||||
Status RenameFile(const string& s, const string& t) override {
|
||||
return target_->RenameFile(s, t);
|
||||
}
|
||||
uint64 NowMicros() override { return target_->NowMicros(); }
|
||||
void SleepForMicroseconds(int micros) override {
|
||||
target_->SleepForMicroseconds(micros);
|
||||
}
|
||||
Thread* StartThread(const ThreadOptions& thread_options, const string& name,
|
||||
std::function<void()> fn) override {
|
||||
return target_->StartThread(thread_options, name, fn);
|
||||
}
|
||||
void SchedClosure(std::function<void()> closure) override {
|
||||
target_->SchedClosure(closure);
|
||||
}
|
||||
void SchedClosureAfter(int micros, std::function<void()> closure) override {
|
||||
target_->SchedClosureAfter(micros, closure);
|
||||
}
|
||||
Status LoadLibrary(const char* library_filename, void** handle) override {
|
||||
return target_->LoadLibrary(library_filename, handle);
|
||||
}
|
||||
Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
|
||||
void** symbol) override {
|
||||
return target_->GetSymbolFromLibrary(handle, symbol_name, symbol);
|
||||
}
|
||||
|
||||
private:
|
||||
Env* target_;
|
||||
};
|
||||
|
||||
class Thread {
|
||||
public:
|
||||
Thread() {}
|
||||
|
||||
/// Blocks until the thread of control stops running.
|
||||
virtual ~Thread();
|
||||
|
||||
private:
|
||||
/// No copying allowed
|
||||
Thread(const Thread&);
|
||||
void operator=(const Thread&);
|
||||
};
|
||||
|
||||
/// \brief Options to configure a Thread.
|
||||
///
|
||||
/// Note that the options are all hints, and the
|
||||
/// underlying implementation may choose to ignore it.
|
||||
struct ThreadOptions {
|
||||
/// Thread stack size to use (in bytes).
|
||||
size_t stack_size = 0; // 0: use system default value
|
||||
/// Guard area size to use near thread stacks to use (in bytes)
|
||||
size_t guard_size = 0; // 0: use system default value
|
||||
};
|
||||
|
||||
/// A utility routine: reads contents of named file into `*data`
|
||||
Status ReadFileToString(Env* env, const string& fname, string* data);
|
||||
|
||||
/// A utility routine: write contents of `data` to file named `fname`
|
||||
/// (overwriting existing contents, if any).
|
||||
Status WriteStringToFile(Env* env, const string& fname,
|
||||
const StringPiece& data);
|
||||
|
||||
/// Reads contents of named file and parse as binary encoded proto data
|
||||
/// and store into `*proto`.
|
||||
Status ReadBinaryProto(Env* env, const string& fname,
|
||||
::tensorflow::protobuf::MessageLite* proto);
|
||||
|
||||
} // namespace tensorflow
|
||||
// This file is deprecated, use ../platform/env.h instead.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
#endif // TENSORFLOW_PUBLIC_ENV_H_
|
||||
|
@ -16,155 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PUBLIC_PARTIAL_TENSOR_SHAPE_H_
|
||||
#define TENSORFLOW_PUBLIC_PARTIAL_TENSOR_SHAPE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/public/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class PartialTensorShapeIter; // Declared below
|
||||
|
||||
/// Manages the partially known dimensions of a Tensor and their sizes.
|
||||
class PartialTensorShape {
|
||||
public:
|
||||
/// \brief Construct a `PartialTensorShape` from the provided sizes.
|
||||
/// REQUIRES: `dim_sizes[i] >= 0`
|
||||
explicit PartialTensorShape(gtl::ArraySlice<int64> dim_sizes);
|
||||
PartialTensorShape(std::initializer_list<int64> dim_sizes)
|
||||
: PartialTensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
|
||||
|
||||
/// REQUIRES: `IsValid(proto)`
|
||||
explicit PartialTensorShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Returns `true` iff `proto` is a valid partial tensor shape.
|
||||
static bool IsValid(const TensorShapeProto& proto);
|
||||
|
||||
/// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
|
||||
/// status otherwise.
|
||||
static Status IsValidShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Add a dimension to the end ("inner-most"), returns a new
|
||||
/// PartialTensorShape.
|
||||
/// REQUIRES: `size >= -1`, where -1 means unknown
|
||||
PartialTensorShape Concatenate(int64 size) const;
|
||||
|
||||
/// Appends all the dimensions from `shape`. Returns a new
|
||||
/// PartialTensorShape.
|
||||
PartialTensorShape Concatenate(const PartialTensorShape& shape) const;
|
||||
|
||||
/// Merges all the dimensions from `shape`. Returns
|
||||
/// `InvalidArgument` error if either `shape` has a different rank
|
||||
/// or if any of the dimensions are incompatible.
|
||||
Status MergeWith(const PartialTensorShape& shape,
|
||||
PartialTensorShape* result) const;
|
||||
|
||||
/// Return the number of dimensions in the tensor.
|
||||
int dims() const { return dim_sizes_.size(); }
|
||||
|
||||
/// Return true iff the rank and all of the dimensions are well defined
|
||||
bool IsFullyDefined() const;
|
||||
|
||||
/// Return true iff the ranks match, and if the
|
||||
/// dimensions all either match or one is unknown.
|
||||
bool IsCompatibleWith(const PartialTensorShape& shape) const;
|
||||
|
||||
/// Return true iff the dimensions of `shape` are compatible with
|
||||
/// `*this`.
|
||||
bool IsCompatibleWith(const TensorShape& shape) const;
|
||||
|
||||
/// \brief Returns the number of elements in dimension `d`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
int64 dim_size(int d) const {
|
||||
DCHECK_GE(d, 0);
|
||||
DCHECK_LT(d, dims());
|
||||
return dim_sizes_[d];
|
||||
}
|
||||
|
||||
/// Returns sizes of all dimensions.
|
||||
gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }
|
||||
|
||||
/// Fill `*proto` from `*this`.
|
||||
void AsProto(TensorShapeProto* proto) const;
|
||||
|
||||
// Fill `*tensor_shape` from `*this`.
|
||||
// If `*this` is not fully defined, returns false and
|
||||
// `*tensor_shape` is left in an intermediate state. Otherwise
|
||||
// returns true.
|
||||
bool AsTensorShape(TensorShape* tensor_shape) const;
|
||||
|
||||
/// For error messages.
|
||||
string DebugString() const;
|
||||
static string DebugString(const TensorShapeProto& proto);
|
||||
|
||||
/// \brief Returns a `PartialTensorShape` whose dimensions are
|
||||
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
|
||||
/// considered "unknown".
|
||||
template <typename T>
|
||||
static Status MakePartialShape(const T* dims, int n, PartialTensorShape* out);
|
||||
|
||||
private:
|
||||
/// Create a tensor shape.
|
||||
PartialTensorShape();
|
||||
|
||||
gtl::InlinedVector<int64, 4> dim_sizes_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status PartialTensorShape::MakePartialShape(const T* dims, int n,
|
||||
PartialTensorShape* out) {
|
||||
*out = PartialTensorShape();
|
||||
out->dim_sizes_.reserve(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (dims[i] >= -1) {
|
||||
out->dim_sizes_.push_back(dims[i]);
|
||||
} else {
|
||||
return errors::InvalidArgument("Dimension ", dims[i], " must be >= -1");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
|
||||
/// common predicates on a partially known tensor shape.
|
||||
class PartialTensorShapeUtils {
|
||||
public:
|
||||
static string PartialShapeListString(
|
||||
const gtl::ArraySlice<PartialTensorShape>& shapes) {
|
||||
string result = "[";
|
||||
bool first = true;
|
||||
for (const PartialTensorShape& shape : shapes) {
|
||||
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
|
||||
first = false;
|
||||
}
|
||||
strings::StrAppend(&result, "]");
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool AreCompatible(
|
||||
const gtl::ArraySlice<PartialTensorShape>& shapes0,
|
||||
const gtl::ArraySlice<PartialTensorShape>& shapes1) {
|
||||
if (shapes0.size() == shapes1.size()) {
|
||||
for (int i = 0; i < shapes0.size(); ++i) {
|
||||
if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
|
||||
#endif // TENSORFLOW_PUBLIC_PARTIAL_TENSOR_SHAPE_H_
|
||||
|
@ -16,97 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PUBLIC_STATUS_H_
|
||||
#define TENSORFLOW_PUBLIC_STATUS_H_
|
||||
|
||||
#include <functional>
|
||||
#include <iosfwd>
|
||||
#include <string>
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Status {
|
||||
public:
|
||||
/// Create a success status.
|
||||
Status() : state_(NULL) {}
|
||||
~Status() { delete state_; }
|
||||
|
||||
/// \brief Create a status with the specified error code and msg as a
|
||||
/// human-readable string containing more detailed information.
|
||||
Status(tensorflow::error::Code code, tensorflow::StringPiece msg);
|
||||
|
||||
/// Copy the specified status.
|
||||
Status(const Status& s);
|
||||
void operator=(const Status& s);
|
||||
|
||||
static Status OK() { return Status(); }
|
||||
|
||||
/// Returns true iff the status indicates success.
|
||||
bool ok() const { return (state_ == NULL); }
|
||||
|
||||
tensorflow::error::Code code() const {
|
||||
return ok() ? tensorflow::error::OK : state_->code;
|
||||
}
|
||||
|
||||
const string& error_message() const {
|
||||
return ok() ? empty_string() : state_->msg;
|
||||
}
|
||||
|
||||
bool operator==(const Status& x) const;
|
||||
bool operator!=(const Status& x) const;
|
||||
|
||||
/// \brief If `ok()`, stores `new_status` into `*this`. If `!ok()`,
|
||||
/// preserves the current status, but may augment with additional
|
||||
/// information about `new_status`.
|
||||
///
|
||||
/// Convenient way of keeping track of the first error encountered.
|
||||
/// Instead of:
|
||||
/// `if (overall_status.ok()) overall_status = new_status`
|
||||
/// Use:
|
||||
/// `overall_status.Update(new_status);`
|
||||
void Update(const Status& new_status);
|
||||
|
||||
/// \brief Return a string representation of this status suitable for
|
||||
/// printing. Returns the string `"OK"` for success.
|
||||
string ToString() const;
|
||||
|
||||
private:
|
||||
static const string& empty_string();
|
||||
struct State {
|
||||
tensorflow::error::Code code;
|
||||
string msg;
|
||||
};
|
||||
// OK status has a `NULL` state_. Otherwise, `state_` points to
|
||||
// a `State` structure containing the error code and message(s)
|
||||
State* state_;
|
||||
|
||||
void SlowCopyFrom(const State* src);
|
||||
};
|
||||
|
||||
inline Status::Status(const Status& s)
|
||||
: state_((s.state_ == NULL) ? NULL : new State(*s.state_)) {}
|
||||
|
||||
inline void Status::operator=(const Status& s) {
|
||||
// The following condition catches both aliasing (when this == &s),
|
||||
// and the common case where both s and *this are ok.
|
||||
if (state_ != s.state_) {
|
||||
SlowCopyFrom(s.state_);
|
||||
}
|
||||
}
|
||||
|
||||
inline bool Status::operator==(const Status& x) const {
|
||||
return (this->state_ == x.state_) || (ToString() == x.ToString());
|
||||
}
|
||||
|
||||
inline bool Status::operator!=(const Status& x) const { return !(*this == x); }
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Status& x);
|
||||
|
||||
typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
#define TF_CHECK_OK(val) CHECK_EQ(::tensorflow::Status::OK(), (val))
|
||||
#define TF_QCHECK_OK(val) QCHECK_EQ(::tensorflow::Status::OK(), (val))
|
||||
|
||||
} // namespace tensorflow
|
||||
// This file is deprecated, use ../lib/core/status.h instead.
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
#endif // TENSORFLOW_PUBLIC_STATUS_H_
|
||||
|
@ -16,498 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PUBLIC_TENSOR_H_
|
||||
#define TENSORFLOW_PUBLIC_TENSOR_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_description.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/port.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/public/tensor_shape.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorBuffer; // Forward declaration.
|
||||
class TensorCApi;
|
||||
|
||||
/// Represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
/// Default Tensor constructor. Creates a 1-dimension, 0-element float tensor.
|
||||
Tensor();
|
||||
|
||||
/// \brief Creates a Tensor of the given `type` and `shape`.
|
||||
///
|
||||
/// The underlying buffer is allocated using a `CPUAllocator`.
|
||||
Tensor(DataType type, const TensorShape& shape);
|
||||
|
||||
/// \brief Creates a tensor with the input `type` and `shape`, using the
|
||||
/// allocator `a` to allocate the underlying buffer.
|
||||
///
|
||||
/// `a` must outlive the lifetime of this Tensor.
|
||||
Tensor(Allocator* a, DataType type, const TensorShape& shape);
|
||||
|
||||
/// \brief Creates a tensor with the input `type` and `shape`, using the
|
||||
/// allocator `a` and the specified "allocation_attr" to allocate the
|
||||
/// underlying buffer.
|
||||
///
|
||||
/// `a` must outlive the lifetime of this Tensor.
|
||||
Tensor(Allocator* a, DataType type, const TensorShape& shape,
|
||||
const AllocationAttributes& allocation_attr);
|
||||
|
||||
/// Creates an uninitialized Tensor of the given data type.
|
||||
explicit Tensor(DataType type);
|
||||
|
||||
Tensor(const Tensor& other); /// Copy constructor.
|
||||
|
||||
~Tensor();
|
||||
|
||||
/// Returns the data type.
|
||||
DataType dtype() const { return type_; }
|
||||
|
||||
/// Returns the shape of the tensor.
|
||||
const TensorShape& shape() const { return shape_; }
|
||||
|
||||
/// \brief Convenience accessor for the tensor shape.
|
||||
///
|
||||
/// For all shape accessors, see comments for relevant methods of
|
||||
/// `TensorShape` in `tensor_shape.h`.
|
||||
int dims() const { return shape().dims(); }
|
||||
|
||||
/// Convenience accessor for the tensor shape.
|
||||
int64 dim_size(int d) const { return shape().dim_size(d); }
|
||||
|
||||
/// Convenience accessor for the tensor shape.
|
||||
int64 NumElements() const { return shape().num_elements(); }
|
||||
|
||||
bool IsSameSize(const Tensor& b) const {
|
||||
return shape().IsSameSize(b.shape());
|
||||
}
|
||||
|
||||
// True iff the two tensors use the same underlying refcounted storage
|
||||
bool SharesBufferWith(const Tensor& b) const;
|
||||
|
||||
// The BufferHash of two tensors are equal when they share the same
|
||||
// underlying refcounted storage
|
||||
size_t BufferHash() const;
|
||||
|
||||
/// Has this Tensor been initialized?
|
||||
bool IsInitialized() const;
|
||||
|
||||
/// Returns the estimated memory usage of this tensor.
|
||||
size_t TotalBytes() const;
|
||||
|
||||
/// Assign operator. This tensor shares other's underlying storage.
|
||||
Tensor& operator=(const Tensor& other) {
|
||||
CopyFromInternal(other, other.shape());
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// \brief Copy the other tensor into this tensor and reshape it.
|
||||
///
|
||||
/// This tensor shares other's underlying storage. Returns `true`
|
||||
/// iff `other.shape()` has the same number of elements of the given
|
||||
/// `shape`.
|
||||
bool CopyFrom(const Tensor& other,
|
||||
const TensorShape& shape) TF_MUST_USE_RESULT {
|
||||
if (other.NumElements() != shape.num_elements()) return false;
|
||||
CopyFromInternal(other, shape);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// \brief Slice this tensor along the 1st dimension.
|
||||
|
||||
/// I.e., the returned tensor satisfies
|
||||
/// returned[i, ...] == this[dim0_start + i, ...].
|
||||
/// The returned tensor shares the underlying tensor buffer with this
|
||||
/// tensor.
|
||||
///
|
||||
/// NOTE: The returned tensor may not satisfies the same alignment
|
||||
/// requirement as this tensor depending on the shape. The caller
|
||||
/// must check the returned tensor's alignment before calling certain
|
||||
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
|
||||
///
|
||||
/// REQUIRES: `dims()` >= 1
|
||||
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
|
||||
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
|
||||
|
||||
/// \brief Parse `other` and construct the tensor.
|
||||
|
||||
/// Returns `true` iff the parsing succeeds. If the parsing fails,
|
||||
/// the state of `*this` is unchanged.
|
||||
bool FromProto(const TensorProto& other) TF_MUST_USE_RESULT;
|
||||
bool FromProto(Allocator* a, const TensorProto& other) TF_MUST_USE_RESULT;
|
||||
|
||||
/// \brief Fills in `proto` with `*this` tensor's content.
|
||||
///
|
||||
/// `AsProtoField()` fills in the repeated field for `proto.dtype()`, while
|
||||
/// `AsProtoTensorContent()` encodes the content in `proto.tensor_content()`
|
||||
/// in a compact form.
|
||||
void AsProtoField(TensorProto* proto) const;
|
||||
void AsProtoTensorContent(TensorProto* proto) const;
|
||||
|
||||
/// \brief Return the tensor data as an `Eigen::Tensor` with the type and
|
||||
/// sizes of this `Tensor`.
|
||||
///
|
||||
/// Use these methods when you know the data type and the number of
|
||||
/// dimensions of the Tensor and you want an `Eigen::Tensor`
|
||||
/// automatically sized to the `Tensor` sizes. The implementation check
|
||||
/// fails if either type or sizes mismatch.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```c++
|
||||
///
|
||||
/// typedef float T;
|
||||
/// Tensor my_mat(...built with Shape{rows: 3, cols: 5}...);
|
||||
/// auto mat = my_mat.matrix<T>(); // 2D Eigen::Tensor, 3 x 5.
|
||||
/// auto mat = my_mat.tensor<T, 2>(); // 2D Eigen::Tensor, 3 x 5.
|
||||
/// auto vec = my_mat.vec<T>(); // CHECK fails as my_mat is 2D.
|
||||
/// auto vec = my_mat.tensor<T, 3>(); // CHECK fails as my_mat is 2D.
|
||||
/// auto mat = my_mat.matrix<int32>();// CHECK fails as type mismatch.
|
||||
///
|
||||
/// ```
|
||||
template <typename T>
|
||||
typename TTypes<T>::Vec vec() {
|
||||
return tensor<T, 1>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix matrix() {
|
||||
return tensor<T, 2>();
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor tensor();
|
||||
|
||||
/// \brief Return the tensor data as an `Eigen::Tensor` of the data type and a
|
||||
/// specified shape.
|
||||
///
|
||||
/// These methods allow you to access the data with the dimensions
|
||||
/// and sizes of your choice. You do not need to know the number of
|
||||
/// dimensions of the Tensor to call them. However, they `CHECK` that
|
||||
/// the type matches and the dimensions requested creates an
|
||||
/// `Eigen::Tensor` with the same number of elements as the tensor.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```c++
|
||||
///
|
||||
/// typedef float T;
|
||||
/// Tensor my_ten(...built with Shape{planes: 4, rows: 3, cols: 5}...);
|
||||
/// // 1D Eigen::Tensor, size 60:
|
||||
/// auto flat = my_ten.flat<T>();
|
||||
/// // 2D Eigen::Tensor 12 x 5:
|
||||
/// auto inner = my_ten.flat_inner_dims<T>();
|
||||
/// // 2D Eigen::Tensor 4 x 15:
|
||||
/// auto outer = my_ten.shaped<T, 2>({4, 15});
|
||||
/// // CHECK fails, bad num elements:
|
||||
/// auto outer = my_ten.shaped<T, 2>({4, 8});
|
||||
/// // 3D Eigen::Tensor 6 x 5 x 2:
|
||||
/// auto weird = my_ten.shaped<T, 3>({6, 5, 2});
|
||||
/// // CHECK fails, type mismatch:
|
||||
/// auto bad = my_ten.flat<int32>();
|
||||
///
|
||||
/// ```
|
||||
template <typename T>
|
||||
typename TTypes<T>::Flat flat() {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedFlat unaligned_flat() {
|
||||
return unaligned_shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
|
||||
/// Tensor dimensions but the last one into the first dimension of the result.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix flat_inner_dims() {
|
||||
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
|
||||
if (last_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({NumElements() / last_size, last_size});
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
|
||||
/// Tensor dimensions but the first one into the last dimension of the result.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Matrix flat_outer_dims() {
|
||||
int64 first_size = dims() > 0 ? dim_size(0) : 1;
|
||||
if (first_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({first_size, NumElements() / first_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedTensor unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes);
|
||||
|
||||
/// \brief Return the Tensor data as a `TensorMap` of fixed size 1:
|
||||
/// `TensorMap<TensorFixedSize<T, 1>>`.
|
||||
|
||||
/// Using `scalar()` allows the compiler to perform optimizations as
|
||||
/// the size of the tensor is known at compile time.
|
||||
template <typename T>
|
||||
typename TTypes<T>::Scalar scalar();
|
||||
|
||||
/// Const versions of all the methods above.
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstVec vec() const {
|
||||
return tensor<T, 1>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix matrix() const {
|
||||
return tensor<T, 2>();
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstFlat flat() const {
|
||||
return shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
|
||||
return unaligned_shaped<T, 1>({NumElements()});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix flat_inner_dims() const {
|
||||
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
|
||||
if (last_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({NumElements() / last_size, last_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstMatrix flat_outer_dims() const {
|
||||
int64 first_size = dims() > 0 ? dim_size(0) : 1;
|
||||
if (first_size == 0) {
|
||||
DCHECK_EQ(NumElements(), 0);
|
||||
// Return something empty, avoiding divide by 0
|
||||
return shaped<T, 2>({0, 0});
|
||||
} else {
|
||||
return shaped<T, 2>({first_size, NumElements() / first_size});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const;
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const;
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstScalar scalar() const;
|
||||
|
||||
/// Render the first `max_entries` values in `*this` into a string.
|
||||
string SummarizeValue(int64 max_entries) const;
|
||||
|
||||
/// A human-readable summary of the tensor suitable for debugging.
|
||||
string DebugString() const;
|
||||
|
||||
/// Fill in the `TensorDescription` proto with metadata about the
|
||||
/// tensor that is useful for monitoring and debugging.
|
||||
void FillDescription(TensorDescription* description) const;
|
||||
|
||||
/// \brief Returns a `StringPiece` mapping the current tensor's buffer.
|
||||
///
|
||||
/// The returned `StringPiece` may point to memory location on devices
|
||||
/// that the CPU cannot address directly.
|
||||
///
|
||||
/// NOTE: The underlying tensor buffer is refcounted, so the lifetime
|
||||
/// of the contents mapped by the `StringPiece` matches the lifetime of
|
||||
/// the buffer; callers should arrange to make sure the buffer does
|
||||
/// not get destroyed while the `StringPiece` is still used.
|
||||
///
|
||||
/// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
|
||||
StringPiece tensor_data() const;
|
||||
|
||||
private:
|
||||
DataType type_;
|
||||
TensorShape shape_;
|
||||
TensorBuffer* buf_;
|
||||
|
||||
friend class DMAHelper;
|
||||
friend class TensorCApi;
|
||||
friend class TensorReference; // For access to buf_
|
||||
friend class VariableOp; // For access to set_shape
|
||||
friend class AutoReloadVariableOp; // For access to set_shape
|
||||
|
||||
// Creates a tensor with the input datatype, shape and buf.
|
||||
//
|
||||
// Acquires a ref on buf that belongs to this Tensor.
|
||||
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
|
||||
|
||||
bool CanUseDMA() const;
|
||||
|
||||
// Only needed by variable op to set the shape of an uninitialized
|
||||
// Tensor.
|
||||
// TODO: Remove this when we have a better story for detecting
|
||||
// uninitialized tensors.
|
||||
void set_shape(const TensorShape& shape) { shape_ = shape; }
|
||||
|
||||
void CopyFromInternal(const Tensor& other, const TensorShape& shape);
|
||||
|
||||
template <typename T>
|
||||
T* base() const;
|
||||
};
|
||||
|
||||
// Implementation details
|
||||
|
||||
// Interface to access the raw ref-counted data buffer.
|
||||
class TensorBuffer : public core::RefCounted {
|
||||
public:
|
||||
~TensorBuffer() override {}
|
||||
|
||||
// data() points to a memory region of size() bytes.
|
||||
virtual void* data() const = 0;
|
||||
virtual size_t size() const = 0;
|
||||
|
||||
// If this TensorBuffer is sub-buffer of another TensorBuffer,
|
||||
// returns that TensorBuffer. Otherwise, returns this.
|
||||
virtual TensorBuffer* root_buffer() = 0;
|
||||
|
||||
// Fill metadata about the allocation into the proto.
|
||||
virtual void FillAllocationDescription(
|
||||
AllocationDescription* proto) const = 0;
|
||||
|
||||
template <typename T>
|
||||
T* base() const {
|
||||
return reinterpret_cast<T*>(data());
|
||||
}
|
||||
};
|
||||
|
||||
inline void CheckEigenAlignment(const void* ptr) {
|
||||
#if EIGEN_ALIGN == 1
|
||||
CHECK_EQ(reinterpret_cast<intptr_t>(ptr) % EIGEN_ALIGN_BYTES, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* Tensor::base() const {
|
||||
return buf_ == nullptr ? nullptr : buf_->base<T>();
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor Tensor::tensor() {
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CheckEigenAlignment(base<T>());
|
||||
return typename TTypes<T, NDIMS>::Tensor(base<T>(),
|
||||
shape().AsEigenDSizes<NDIMS>());
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor Tensor::tensor() const {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
return typename TTypes<T, NDIMS>::ConstTensor(base<const T>(),
|
||||
shape().AsEigenDSizes<NDIMS>());
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::Tensor Tensor::shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::Tensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedTensor Tensor::unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) {
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::UnalignedTensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::ConstTensor Tensor::shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::ConstTensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T, size_t NDIMS>
|
||||
typename TTypes<T, NDIMS>::UnalignedConstTensor Tensor::unaligned_shaped(
|
||||
gtl::ArraySlice<int64> new_sizes) const {
|
||||
CHECK_EQ(dtype(), DataTypeToEnum<T>::v());
|
||||
CHECK_EQ(NDIMS, new_sizes.size());
|
||||
int64 new_num_elements = 1;
|
||||
Eigen::array<Eigen::DenseIndex, NDIMS> dims;
|
||||
for (int d = 0; d < NDIMS; d++) {
|
||||
new_num_elements *= new_sizes[d];
|
||||
dims[d] = new_sizes[d];
|
||||
}
|
||||
CHECK_EQ(new_num_elements, NumElements());
|
||||
return typename TTypes<T, NDIMS>::UnalignedConstTensor(base<T>(), dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::Scalar Tensor::scalar() {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
|
||||
return typename TTypes<T>::Scalar(base<T>());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
typename TTypes<T>::ConstScalar Tensor::scalar() const {
|
||||
CheckEigenAlignment(base<T>());
|
||||
CHECK_EQ(1, NumElements()) << "Must have a one element tensor";
|
||||
return typename TTypes<T>::ConstScalar(base<T>());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
// This file is deprecated, use ../framework/tensor.h instead.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
#endif // TENSORFLOW_PUBLIC_TENSOR_H_
|
||||
|
@ -16,233 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_
|
||||
#define TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorShapeIter; // Declared below
|
||||
|
||||
/// Manages the dimensions of a Tensor and their sizes.
|
||||
class TensorShape {
|
||||
public:
|
||||
/// \brief Construct a `TensorShape` from the provided sizes.
|
||||
/// REQUIRES: `dim_sizes[i] >= 0`
|
||||
explicit TensorShape(gtl::ArraySlice<int64> dim_sizes);
|
||||
TensorShape(std::initializer_list<int64> dim_sizes)
|
||||
: TensorShape(gtl::ArraySlice<int64>(dim_sizes)) {}
|
||||
|
||||
/// REQUIRES: `IsValid(proto)`
|
||||
explicit TensorShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Create a tensor shape with no dimensions and one element, which you can
|
||||
/// then call `AddDim()` on.
|
||||
TensorShape();
|
||||
|
||||
/// Returns `true` iff `proto` is a valid tensor shape.
|
||||
static bool IsValid(const TensorShapeProto& proto);
|
||||
|
||||
/// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
|
||||
/// status otherwise.
|
||||
static Status IsValidShape(const TensorShapeProto& proto);
|
||||
|
||||
/// Clear a tensor shape
|
||||
void Clear();
|
||||
|
||||
/// \brief Add a dimension to the end ("inner-most").
|
||||
/// REQUIRES: `size >= 0`
|
||||
void AddDim(int64 size);
|
||||
|
||||
/// Appends all the dimensions from `shape`.
|
||||
void AppendShape(const TensorShape& shape);
|
||||
|
||||
/// \brief Insert a dimension somewhere in the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d <= dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void InsertDim(int d, int64 size);
|
||||
|
||||
/// \brief Modifies the size of the dimension `d` to be `size`
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
/// REQUIRES: `size >= 0`
|
||||
void set_dim(int d, int64 size);
|
||||
|
||||
/// \brief Removes dimension `d` from the `TensorShape`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
void RemoveDim(int d);
|
||||
|
||||
/// Return the number of dimensions in the tensor.
|
||||
int dims() const { return dim_sizes_.size(); }
|
||||
|
||||
/// \brief Returns the number of elements in dimension `d`.
|
||||
/// REQUIRES: `0 <= d < dims()`
|
||||
// TODO(touts): Rename to `dimension()` to match
|
||||
// `Eigen::Tensor::dimension()`?
|
||||
int64 dim_size(int d) const {
|
||||
DCHECK_GE(d, 0);
|
||||
DCHECK_LT(d, dims());
|
||||
return dim_sizes_[d];
|
||||
}
|
||||
|
||||
/// Returns sizes of all dimensions.
|
||||
gtl::ArraySlice<int64> dim_sizes() const { return dim_sizes_; }
|
||||
|
||||
/// \brief Returns the number of elements in the tensor.
|
||||
///
|
||||
/// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
|
||||
/// which uses `ptrdiff_t`.
|
||||
int64 num_elements() const { return num_elements_; }
|
||||
|
||||
/// Returns true if `*this` and `b` have the same sizes. Ignores
|
||||
/// dimension names.
|
||||
bool IsSameSize(const TensorShape& b) const;
|
||||
bool operator==(const TensorShape& b) const { return IsSameSize(b); }
|
||||
|
||||
/// Fill `*proto` from `*this`.
|
||||
void AsProto(TensorShapeProto* proto) const;
|
||||
|
||||
/// Fill `*dsizes` from `*this`.
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizes() const;
|
||||
|
||||
/// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
|
||||
/// which case we pad the rest of the sizes with 1.
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPadding() const;
|
||||
|
||||
/// For iterating through the dimensions.
|
||||
TensorShapeIter begin() const;
|
||||
TensorShapeIter end() const;
|
||||
|
||||
/// For error messages.
|
||||
string DebugString() const;
|
||||
|
||||
/// Same as DebugString()
|
||||
string ShortDebugString() const { return DebugString(); }
|
||||
// TODO(irving): Remove, used to be different but isn't now.
|
||||
|
||||
/// Same as `TensorShape(proto).ShortDebugString()` but doesn't crash for
|
||||
/// invalid protos.
|
||||
static string ShortDebugString(const TensorShapeProto& proto);
|
||||
// TODO(irving): Rename to DebugString.
|
||||
|
||||
private:
|
||||
// Recalculates the dimensions of this tensor after they are modified.
|
||||
void recompute_dims();
|
||||
|
||||
// TODO(josh11b): Maybe use something from the Eigen Tensor library
|
||||
// for the sizes.
|
||||
gtl::InlinedVector<int64, 4> dim_sizes_;
|
||||
|
||||
// total number of elements (avoids recomputing it each time).
|
||||
int64 num_elements_;
|
||||
};
|
||||
|
||||
struct TensorShapeDim {
|
||||
explicit TensorShapeDim(int64 s) : size(s) {}
|
||||
int size;
|
||||
};
|
||||
|
||||
class TensorShapeIter {
|
||||
public:
|
||||
TensorShapeIter(const TensorShape* shape, int d) : shape_(shape), d_(d) {}
|
||||
bool operator==(const TensorShapeIter& rhs) {
|
||||
DCHECK(shape_ == rhs.shape_);
|
||||
return d_ == rhs.d_;
|
||||
}
|
||||
bool operator!=(const TensorShapeIter& rhs) {
|
||||
DCHECK(shape_ == rhs.shape_);
|
||||
return d_ != rhs.d_;
|
||||
}
|
||||
void operator++() { ++d_; }
|
||||
TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }
|
||||
|
||||
private:
|
||||
const TensorShape* shape_;
|
||||
int d_;
|
||||
};
|
||||
|
||||
/// \brief Static helper routines for `TensorShape`. Includes a few common
|
||||
/// predicates on a tensor shape.
|
||||
class TensorShapeUtils {
|
||||
public:
|
||||
static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }
|
||||
|
||||
static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }
|
||||
|
||||
static bool IsVectorOrHigher(const TensorShape& shape) {
|
||||
return shape.dims() >= 1;
|
||||
}
|
||||
|
||||
static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }
|
||||
|
||||
static bool IsMatrixOrHigher(const TensorShape& shape) {
|
||||
return shape.dims() >= 2;
|
||||
}
|
||||
|
||||
/// \brief Returns a `TensorShape` whose dimensions are
|
||||
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.
|
||||
template <typename T>
|
||||
static Status MakeShape(const T* dims, int n, TensorShape* out) {
|
||||
*out = TensorShape();
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (dims[i] >= 0) {
|
||||
out->AddDim(dims[i]);
|
||||
} else {
|
||||
return errors::InvalidArgument("Dimension ", dims[i], " must be >= 0");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes) {
|
||||
string result = "[";
|
||||
bool first = true;
|
||||
for (const TensorShape& shape : shapes) {
|
||||
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
|
||||
first = false;
|
||||
}
|
||||
strings::StrAppend(&result, "]");
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool StartsWith(const TensorShape& shape0, const TensorShape& shape1);
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Template method implementation details below
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizes() const {
|
||||
CHECK_EQ(NDIMS, dims()) << "Asking for tensor of " << NDIMS
|
||||
<< " for a tensor of " << dims() << " dimensions";
|
||||
return AsEigenDSizesWithPadding<NDIMS>();
|
||||
}
|
||||
|
||||
template <int NDIMS>
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> TensorShape::AsEigenDSizesWithPadding()
|
||||
const {
|
||||
CHECK_GE(NDIMS, dims()) << "Asking for tensor of " << NDIMS
|
||||
<< " for a tensor of " << dims() << " dimensions";
|
||||
Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
|
||||
for (int d = 0; d < dims(); d++) {
|
||||
dsizes[d] = dim_size(d);
|
||||
}
|
||||
for (int d = dims(); d < NDIMS; d++) {
|
||||
dsizes[d] = 1;
|
||||
}
|
||||
return dsizes;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
// This file is deprecated, use ../framework/tensor_shape.h instead.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
#endif // TENSORFLOW_PUBLIC_TENSOR_SHAPE_H_
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/python/client/tf_session_helper.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
%}
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
%apply int { tensorflow::error::Code }; // Treat the enum as an integer.
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
%}
|
||||
|
||||
%typemap(out, fragment="StatusNotOK") tensorflow::Status {
|
||||
@ -49,7 +49,7 @@ if (pywrap_status) {
|
||||
%}
|
||||
|
||||
%fragment("StatusNotOK", "header") %{
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace {
|
||||
// Initialized on the first call to RaiseStatusNotOK().
|
||||
@ -118,6 +118,6 @@ void RaiseStatusNotOK(const tensorflow::Status& status, swig_type_info *type) {
|
||||
%unignore tensorflow::Status::~Status;
|
||||
%ignore tensorflow::Status::operator=;
|
||||
|
||||
%include "tensorflow/core/public/status.h"
|
||||
%include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
%unignoreall
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
%import(module="tensorflow.python.pywrap_tensorflow") "tensorflow/python/lib/core/status.i"
|
||||
|
||||
%inline %{
|
||||
#include "tensorflow/core/public/status.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
tensorflow::Status NotOkay() {
|
||||
return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, "Testing 1 2 3");
|
||||
|
Loading…
Reference in New Issue
Block a user