Add an option such that the cached host_value can be discarded
PiperOrigin-RevId: 317315157 Change-Id: I9d7145390a526003069321c7e04794e139a53c09
This commit is contained in:
parent
6f2be48cde
commit
07c54454ee
tensorflow/compiler/xla/pjrt
@ -1077,13 +1077,17 @@ Status PjRtBuffer::CopyToHostAsync() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral() {
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
|
||||
const bool discard_cached_copy) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
|
||||
TF_RETURN_IF_ERROR(CopyToHostAsync());
|
||||
std::shared_ptr<HostValue> host_value;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
host_value = host_value_;
|
||||
if (discard_cached_copy) {
|
||||
host_value_ = nullptr;
|
||||
}
|
||||
}
|
||||
if (host_value == nullptr) {
|
||||
return InvalidArgument("ToLiteral called on invalid buffer");
|
||||
|
@ -478,8 +478,12 @@ class PjRtBuffer {
|
||||
|
||||
// Returns the buffer's value as an XLA Literal. If the value has previously
|
||||
// been prefetched to the host, then returns the prefetched version, otherwise
|
||||
// copies the buffer to the host. Blocks until the value is ready.
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral();
|
||||
// copies the buffer to the host. Blocks until the value is ready. If
|
||||
// `discard_cached_copy` is true then buffer will no longer keep hold of a
|
||||
// cached copy of the literal (i.e. The reference to the host value will be
|
||||
// removed.)
|
||||
StatusOr<std::shared_ptr<Literal>> ToLiteral(
|
||||
bool discard_cached_copy = false);
|
||||
|
||||
// Initiates a copy of the buffer to the host. Does not block waiting for
|
||||
// the transfer to complete. The value can be retrieved by a later call to
|
||||
|
Loading…
Reference in New Issue
Block a user