Add an option such that the cached host_value can be discarded

PiperOrigin-RevId: 317315157
Change-Id: I9d7145390a526003069321c7e04794e139a53c09
This commit is contained in:
Tamara Norman 2020-06-19 08:48:51 -07:00 committed by TensorFlower Gardener
parent 6f2be48cde
commit 07c54454ee
2 changed files with 11 additions and 3 deletions
tensorflow/compiler/xla/pjrt

View File

@ -1077,13 +1077,17 @@ Status PjRtBuffer::CopyToHostAsync() {
return Status::OK();
}
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral() {
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
TF_RETURN_IF_ERROR(CopyToHostAsync());
std::shared_ptr<HostValue> host_value;
{
absl::MutexLock lock(&mu_);
host_value = host_value_;
if (discard_cached_copy) {
host_value_ = nullptr;
}
}
if (host_value == nullptr) {
return InvalidArgument("ToLiteral called on invalid buffer");

View File

@ -478,8 +478,12 @@ class PjRtBuffer {
// Returns the buffer's value as an XLA Literal. If the value has previously
// been prefetched to the host, then returns the prefetched version, otherwise
// copies the buffer to the host. Blocks until the value is ready.
StatusOr<std::shared_ptr<Literal>> ToLiteral();
// copies the buffer to the host. Blocks until the value is ready. If
// `discard_cached_copy` is true then buffer will no longer keep hold of a
// cached copy of the literal (i.e. The reference to the host value will be
// removed.)
StatusOr<std::shared_ptr<Literal>> ToLiteral(
bool discard_cached_copy = false);
// Initiates a copy of the buffer to the host. Does not block waiting for
// the transfer to complete. The value can be retrieved by a later call to