Made the Tensor constructor that takes a TensorBuffer public.
PiperOrigin-RevId: 262940594
This commit is contained in:
parent
5daa70bfcf
commit
1cb425058e
@ -54,6 +54,44 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
|
||||
} // namespace batch_util
|
||||
|
||||
/// @ingroup core
|
||||
|
||||
/// Interface to access the raw ref-counted data buffer.
|
||||
class TensorBuffer : public core::RefCounted {
|
||||
public:
|
||||
explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
|
||||
~TensorBuffer() override {}
|
||||
|
||||
/// \brief data() points to a memory region of size() bytes.
|
||||
///
|
||||
/// NOTE(mrry): The `data()` method is not virtual for performance reasons.
|
||||
/// It can be called multiple times when the contents of a `Tensor` are
|
||||
/// accessed, and so making it non-virtual allows the body to be inlined.
|
||||
void* data() const { return data_; }
|
||||
|
||||
/// \brief Size (in bytes) of the buffer.
|
||||
virtual size_t size() const = 0;
|
||||
|
||||
/// \brief If this TensorBuffer is sub-buffer of another TensorBuffer,
|
||||
/// returns that TensorBuffer. Otherwise, returns this.
|
||||
virtual TensorBuffer* root_buffer() = 0;
|
||||
|
||||
/// \brief Fills metadata about the allocation into the proto.
|
||||
virtual void FillAllocationDescription(
|
||||
AllocationDescription* proto) const = 0;
|
||||
|
||||
/// \brief Helper method to reinterpret the buffer as an array of `T`.
|
||||
template <typename T>
|
||||
T* base() const {
|
||||
return reinterpret_cast<T*>(data());
|
||||
}
|
||||
|
||||
/// \brief Whether this TensorBuffer owns the underlying memory.
|
||||
virtual bool OwnsMemory() const { return true; }
|
||||
|
||||
private:
|
||||
void* const data_;
|
||||
};
|
||||
|
||||
/// Represents an n-dimensional array of values.
|
||||
class Tensor {
|
||||
public:
|
||||
@ -108,6 +146,11 @@ class Tensor {
|
||||
Tensor(Allocator* a, DataType type, const TensorShape& shape,
|
||||
const AllocationAttributes& allocation_attr);
|
||||
|
||||
/// \brief Creates a tensor with the input datatype, shape and buf.
|
||||
///
|
||||
/// Acquires a ref on buf that belongs to this Tensor.
|
||||
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
|
||||
|
||||
/// \brief Creates an empty Tensor of the given data type.
|
||||
///
|
||||
/// Like Tensor(), returns a 1-dimensional, 0-element Tensor with
|
||||
@ -606,20 +649,16 @@ class Tensor {
|
||||
TensorShape shape_;
|
||||
TensorBuffer* buf_;
|
||||
|
||||
friend class DMAHelper;
|
||||
friend class TensorCApi;
|
||||
friend class TensorCord; // For access to buf_
|
||||
friend class TensorReference; // For access to buf_
|
||||
friend class VariableOp; // For access to set_shape
|
||||
friend class AutoReloadVariableOp; // For access to set_shape
|
||||
friend class TensorTestHelper; // For access to set_shape
|
||||
friend class CastOpBase; // For access to set_dtype;
|
||||
friend class DMAHelper; // For access to buf_.
|
||||
friend class TensorCApi; // For access to buf_.
|
||||
friend class TensorReference; // For access to buf_.
|
||||
friend class VariableOp; // For access to set_shape.
|
||||
friend class AutoReloadVariableOp; // For access to set_shape.
|
||||
friend class TensorTestHelper; // For access to set_shape.
|
||||
friend class CastOpBase; // For access to set_dtype.
|
||||
friend class OpKernelContext; // For access to RefCountIsOne().
|
||||
friend class ScopedAllocator; // For access to buf_.
|
||||
friend class XlaTensor; // For access to RefCountIsOne().
|
||||
friend class XlaTensorBuffer; // For access to the private constructor taking
|
||||
// the buffer
|
||||
friend class Var;
|
||||
template <typename Device, typename T>
|
||||
friend class AssignVariableOp; // For access to RefCountIsOne().
|
||||
template <typename Device, typename T>
|
||||
@ -636,11 +675,6 @@ class Tensor {
|
||||
Tensor* parent, Tensor* element,
|
||||
int64 index); // For access to RefCountIsOne().
|
||||
|
||||
// Creates a tensor with the input datatype, shape and buf.
|
||||
//
|
||||
// Acquires a ref on buf that belongs to this Tensor.
|
||||
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
|
||||
|
||||
bool CanUseDMA() const;
|
||||
|
||||
// Only needed by variable op to set the shape of an uninitialized
|
||||
@ -673,40 +707,6 @@ class Tensor {
|
||||
|
||||
// START_SKIP_DOXYGEN
|
||||
|
||||
// Interface to access the raw ref-counted data buffer.
|
||||
class TensorBuffer : public core::RefCounted {
|
||||
public:
|
||||
explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
|
||||
~TensorBuffer() override {}
|
||||
|
||||
// data() points to a memory region of size() bytes.
|
||||
//
|
||||
// NOTE(mrry): The `data()` method is not virtual for performance reasons.
|
||||
// It can be called multiple times when the contents of a `Tensor` are
|
||||
// accessed, and so making it non-virtual allows the body to be inlined.
|
||||
void* data() const { return data_; }
|
||||
virtual size_t size() const = 0;
|
||||
|
||||
// If this TensorBuffer is sub-buffer of another TensorBuffer,
|
||||
// returns that TensorBuffer. Otherwise, returns this.
|
||||
virtual TensorBuffer* root_buffer() = 0;
|
||||
|
||||
// Fill metadata about the allocation into the proto.
|
||||
virtual void FillAllocationDescription(
|
||||
AllocationDescription* proto) const = 0;
|
||||
|
||||
template <typename T>
|
||||
T* base() const {
|
||||
return reinterpret_cast<T*>(data());
|
||||
}
|
||||
|
||||
// Whether this TensorBuffer owns the underlying memory.
|
||||
virtual bool OwnsMemory() const { return true; }
|
||||
|
||||
private:
|
||||
void* const data_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
T* Tensor::base() const {
|
||||
return buf_ == nullptr ? nullptr : buf_->base<T>();
|
||||
|
Loading…
x
Reference in New Issue
Block a user