Add TPU Transfer Manager Interface and Noncopyable Buffer support
PiperOrigin-RevId: 325112371 Change-Id: I96a704a62d5acf3305e0328eba621455bb4290bb
This commit is contained in:
parent
76bb55a271
commit
44be60723e
tensorflow/stream_executor/tpu
@ -52,6 +52,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "noncopyable_buffer",
|
||||
hdrs = ["noncopyable_buffer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_node_context_c_api_hdrs",
|
||||
hdrs = ["tpu_node_context_c_api.h"],
|
||||
@ -189,6 +201,16 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_transfer_manager_interface",
|
||||
hdrs = ["tpu_transfer_manager_interface.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":noncopyable_buffer",
|
||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_transfer_manager",
|
||||
srcs = ["tpu_transfer_manager_registration.cc"],
|
||||
@ -210,6 +232,7 @@ cc_library(
|
||||
":status_helper",
|
||||
":tpu_executor_base",
|
||||
":tpu_executor_c_api_hdrs",
|
||||
":tpu_transfer_manager_interface",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
|
112
tensorflow/stream_executor/tpu/noncopyable_buffer.h
Normal file
112
tensorflow/stream_executor/tpu/noncopyable_buffer.h
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/base/casts.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Uncopyable buffer type with optional ownership of the underlying data. If
|
||||
// data is not owned then ensuring lifetime of the data exceeds the lifetime of
|
||||
// the buffer is the responsibility of the user.
|
||||
class NoncopyableBuffer {
|
||||
public:
|
||||
NoncopyableBuffer() = default;
|
||||
|
||||
// Allocate an owning buffer without initializing the data. Useful when it
|
||||
// will be filled by a subsequent function and want to avoid initialization
|
||||
// cost. Size is specified in number of uint32's.
|
||||
explicit NoncopyableBuffer(size_t size)
|
||||
: data_(new uint32[size]), buf_(data_.get()), size_(size) {}
|
||||
|
||||
// Allocates an owning buffer and initializes it with the specified data. Size
|
||||
// is specified in number of uint32's.
|
||||
NoncopyableBuffer(size_t size, absl::optional<uint32> value)
|
||||
: NoncopyableBuffer(size) {
|
||||
#ifndef MEMORY_SANITIZER
|
||||
if (!value.has_value()) {
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
uint32 v = value.value_or(0);
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
data_[i] = v;
|
||||
}
|
||||
}
|
||||
|
||||
// Directly use buf pointer without copying it to owning data_. This delays
|
||||
// the memcpy until mutable access is requested. "buf" is not owned by this
|
||||
// data structure, so it is the user's duty to ensure the live range of "buf"
|
||||
// is longer than this data structure.
|
||||
NoncopyableBuffer(const uint8* buf, uint64 size) // Size is in uint8's.
|
||||
: buf_(buf), size_(size / sizeof(uint32)) {
|
||||
CHECK_EQ(size % sizeof(uint32), 0);
|
||||
}
|
||||
NoncopyableBuffer(const uint32* buf, uint64 size) // Size is in uint32's.
|
||||
: buf_(buf), size_(size) {}
|
||||
|
||||
NoncopyableBuffer(const NoncopyableBuffer&) = delete;
|
||||
NoncopyableBuffer(NoncopyableBuffer&&) = default;
|
||||
|
||||
NoncopyableBuffer& operator=(const NoncopyableBuffer&) = delete;
|
||||
NoncopyableBuffer& operator=(NoncopyableBuffer&&) = default;
|
||||
|
||||
// Ensure that the buffer owns the data and returns a mutable view into the
|
||||
// owned data for modification.
|
||||
absl::Span<uint32> mutable_data() {
|
||||
if (data_ == nullptr) {
|
||||
data_.reset(new uint32[size_]);
|
||||
memcpy(data_.get(), buf_, size_ * sizeof(uint32));
|
||||
buf_ = data_.get();
|
||||
}
|
||||
return absl::Span<uint32>(data_.get(), size_);
|
||||
}
|
||||
|
||||
absl::Span<const uint32> const_data() const {
|
||||
return absl::Span<const uint32>(absl::bit_cast<uint32*>(buf_), size_);
|
||||
}
|
||||
// Clone the content to a given buffer.
|
||||
void CloneTo(void* buf) { memcpy(buf, buf_, size_ * sizeof(uint32)); }
|
||||
|
||||
// Return true if data is owned by this buffer (have been copied to `data_`).
|
||||
bool owns_data() const { return data_ != nullptr; }
|
||||
|
||||
// Returns a copy of the object that owns its buffer.
|
||||
NoncopyableBuffer Clone() const {
|
||||
NoncopyableBuffer clone(size_);
|
||||
memcpy(clone.data_.get(), buf_, size_ * sizeof(uint32));
|
||||
return clone;
|
||||
}
|
||||
|
||||
private:
|
||||
// If data_ != nullptr then buf_ == data_.get()
|
||||
std::unique_ptr<uint32[]> data_; // Owning data pointer.
|
||||
const void* buf_; // Non-owning data pointer.
|
||||
uint64 size_; // Size in number of uint32's.
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
|
@ -22,10 +22,11 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/stream_executor/stream_executor.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TpuTransferManager : public xla::TransferManager {
|
||||
class TpuTransferManager : public xla::TpuTransferManagerInterface {
|
||||
public:
|
||||
TpuTransferManager();
|
||||
~TpuTransferManager() override;
|
||||
@ -61,6 +62,12 @@ class TpuTransferManager : public xla::TransferManager {
|
||||
LOG(FATAL) << "Not yet implemented";
|
||||
}
|
||||
|
||||
Status TransferBuffersToInfeed(
|
||||
se::StreamExecutor* executor,
|
||||
const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) override {
|
||||
LOG(FATAL) << "Not yet implemented.";
|
||||
}
|
||||
|
||||
Status ResetDevices(
|
||||
absl::Span<stream_executor::StreamExecutor* const> executor) override {
|
||||
LOG(FATAL) << "Not yet implemented";
|
||||
|
@ -0,0 +1,34 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_INTERFACE_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_INTERFACE_H_
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/stream_executor/tpu/noncopyable_buffer.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
class TpuTransferManagerInterface : public xla::TransferManager {
|
||||
virtual Status TransferBuffersToInfeed(
|
||||
se::StreamExecutor* executor,
|
||||
const std::deque<tensorflow::tpu::NoncopyableBuffer>& buffers) = 0;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_INTERFACE_H_
|
Loading…
Reference in New Issue
Block a user