diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 737ae9efa73..9681e99c987 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -54,6 +54,44 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); } // namespace batch_util /// @ingroup core + +/// Interface to access the raw ref-counted data buffer. +class TensorBuffer : public core::RefCounted { + public: + explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {} + ~TensorBuffer() override {} + + /// \brief data() points to a memory region of size() bytes. + /// + /// NOTE(mrry): The `data()` method is not virtual for performance reasons. + /// It can be called multiple times when the contents of a `Tensor` are + /// accessed, and so making it non-virtual allows the body to be inlined. + void* data() const { return data_; } + + /// \brief Size (in bytes) of the buffer. + virtual size_t size() const = 0; + + /// \brief If this TensorBuffer is sub-buffer of another TensorBuffer, + /// returns that TensorBuffer. Otherwise, returns this. + virtual TensorBuffer* root_buffer() = 0; + + /// \brief Fills metadata about the allocation into the proto. + virtual void FillAllocationDescription( + AllocationDescription* proto) const = 0; + + /// \brief Helper method to reinterpret the buffer as an array of `T`. + template + T* base() const { + return reinterpret_cast(data()); + } + + /// \brief Whether this TensorBuffer owns the underlying memory. + virtual bool OwnsMemory() const { return true; } + + private: + void* const data_; +}; + /// Represents an n-dimensional array of values. class Tensor { public: @@ -108,6 +146,11 @@ class Tensor { Tensor(Allocator* a, DataType type, const TensorShape& shape, const AllocationAttributes& allocation_attr); + /// \brief Creates a tensor with the input datatype, shape and buf. + /// + /// Acquires a ref on buf that belongs to this Tensor. + Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf); + /// \brief Creates an empty Tensor of the given data type. /// /// Like Tensor(), returns a 1-dimensional, 0-element Tensor with @@ -606,20 +649,16 @@ class Tensor { TensorShape shape_; TensorBuffer* buf_; - friend class DMAHelper; - friend class TensorCApi; - friend class TensorCord; // For access to buf_ - friend class TensorReference; // For access to buf_ - friend class VariableOp; // For access to set_shape - friend class AutoReloadVariableOp; // For access to set_shape - friend class TensorTestHelper; // For access to set_shape - friend class CastOpBase; // For access to set_dtype; + friend class DMAHelper; // For access to buf_. + friend class TensorCApi; // For access to buf_. + friend class TensorReference; // For access to buf_. + friend class VariableOp; // For access to set_shape. + friend class AutoReloadVariableOp; // For access to set_shape. + friend class TensorTestHelper; // For access to set_shape. + friend class CastOpBase; // For access to set_dtype. friend class OpKernelContext; // For access to RefCountIsOne(). friend class ScopedAllocator; // For access to buf_. friend class XlaTensor; // For access to RefCountIsOne(). - friend class XlaTensorBuffer; // For access to the private constructor taking - // the buffer - friend class Var; template friend class AssignVariableOp; // For access to RefCountIsOne(). template @@ -636,11 +675,6 @@ class Tensor { Tensor* parent, Tensor* element, int64 index); // For access to RefCountIsOne(). - // Creates a tensor with the input datatype, shape and buf. - // - // Acquires a ref on buf that belongs to this Tensor. - Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf); - bool CanUseDMA() const; // Only needed by variable op to set the shape of an uninitialized @@ -673,40 +707,6 @@ class Tensor { // START_SKIP_DOXYGEN -// Interface to access the raw ref-counted data buffer. -class TensorBuffer : public core::RefCounted { - public: - explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {} - ~TensorBuffer() override {} - - // data() points to a memory region of size() bytes. - // - // NOTE(mrry): The `data()` method is not virtual for performance reasons. - // It can be called multiple times when the contents of a `Tensor` are - // accessed, and so making it non-virtual allows the body to be inlined. - void* data() const { return data_; } - virtual size_t size() const = 0; - - // If this TensorBuffer is sub-buffer of another TensorBuffer, - // returns that TensorBuffer. Otherwise, returns this. - virtual TensorBuffer* root_buffer() = 0; - - // Fill metadata about the allocation into the proto. - virtual void FillAllocationDescription( - AllocationDescription* proto) const = 0; - - template - T* base() const { - return reinterpret_cast(data()); - } - - // Whether this TensorBuffer owns the underlying memory. - virtual bool OwnsMemory() const { return true; } - - private: - void* const data_; -}; - template T* Tensor::base() const { return buf_ == nullptr ? nullptr : buf_->base();