[XLA] Do not needlessly store wrapped Tensor and ScopedShapedBuffer inside XlaTensor on a heap

Use absl::optional instead of std::unique_ptr to store them inside the class instead.

PiperOrigin-RevId: 316523861
Change-Id: I8f54f64e5661a877b7c9807465983d8132920474
This commit is contained in:
George Karpenkov 2020-06-15 12:43:34 -07:00 committed by TensorFlower Gardener
parent 08cbfe4090
commit 51373058de

View File

@ -55,7 +55,7 @@ class XlaTensor {
// manage the memory for these tensors a ShapedBuffer may be required. // manage the memory for these tensors a ShapedBuffer may be required.
// Return true if this XlaTensor contains a ShapedBuffer. // Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
// Return the contained ShapedBuffer. // Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer() // REQUIRES: has_shaped_buffer()
const xla::ShapedBuffer& shaped_buffer() const { const xla::ShapedBuffer& shaped_buffer() const {
@ -68,8 +68,7 @@ class XlaTensor {
} }
// Mutates the XlaTensor to set the ShapedBuffer. // Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ = shaped_buffer_ = std::move(shaped_buffer);
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
} }
// Some tensors on the device may have known values on the host. We use these // Some tensors on the device may have known values on the host. We use these
@ -77,14 +76,12 @@ class XlaTensor {
// host value already. // host value already.
// Return true if this XlaTensor contains a host tensor. // Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; } bool has_host_tensor() const { return host_tensor_.has_value(); }
// Return the contained host tensor. // Return the contained host tensor.
// REQUIRES: has_host_tensor() // REQUIRES: has_host_tensor()
const Tensor& host_tensor() const { return *host_tensor_; } const Tensor& host_tensor() const { return *host_tensor_; }
// Sets the contained host tensor. // Sets the contained host tensor.
void set_host_tensor(const Tensor& tensor) { void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
host_tensor_.reset(new Tensor(tensor));
}
// Adds synchronization events to 'stream' that wait for this tensor to be // Adds synchronization events to 'stream' that wait for this tensor to be
// defined on 'stream'. Does nothing if the tensor is already defined on that // defined on 'stream'. Does nothing if the tensor is already defined on that
@ -111,9 +108,9 @@ class XlaTensor {
private: private:
// The optional contained ShapedBuffer. // The optional contained ShapedBuffer.
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_; absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value. // An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_; absl::optional<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been // An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content // defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined. // is always defined.