Add TpuExecutorInterface::GetCoreLocationExternal() to TPU API.
PiperOrigin-RevId: 327082333 Change-Id: Ifd130fd49d635a137f407b28fe89243126cef800
This commit is contained in:
parent
6684bae7c6
commit
f8f221e367
@ -96,6 +96,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateStream);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_CreateStreamDependency);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_GetStatus);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_GetCoreLocation);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_AllocateEvent);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_DeallocateEvent);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutor_PollForEventStatus);
|
||||
|
@ -308,6 +308,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tpu_platform_interface",
|
||||
":tpu_topology_external",
|
||||
"//tensorflow/stream_executor:stream_executor_headers",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
],
|
||||
|
@ -80,6 +80,11 @@ Status TpuExecutor::GetStatus(Stream* stream) {
|
||||
return status.status();
|
||||
}
|
||||
|
||||
tpu::TpuCoreLocationExternal TpuExecutor::GetCoreLocationExternal() const {
|
||||
return tpu::TpuCoreLocationExternal(
|
||||
tpu::ExecutorApiFn()->TpuExecutor_GetCoreLocationFn(executor_));
|
||||
}
|
||||
|
||||
bool TpuExecutor::AllocateStream(Stream* stream) {
|
||||
return tpu::ExecutorApiFn()->TpuExecutor_AllocateStreamFn(
|
||||
executor_, stream_map().at(stream->implementation()));
|
||||
|
@ -100,6 +100,8 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
|
||||
|
||||
absl::optional<stream_executor::AllocatorStats> GetAllocatorStats() override;
|
||||
|
||||
tpu::TpuCoreLocationExternal GetCoreLocationExternal() const override;
|
||||
|
||||
Status GetStatus(Stream* stream) override;
|
||||
|
||||
std::unique_ptr<::stream_executor::internal::StreamInterface>
|
||||
|
@ -65,6 +65,8 @@ bool TpuExecutor_CreateStreamDependency(SE_StreamExecutor* executor,
|
||||
void TpuExecutor_GetStatus(SE_StreamExecutor* executor, SE_Stream* stream,
|
||||
SE_Status* status);
|
||||
|
||||
SE_TpuTopology_Core* TpuExecutor_GetCoreLocation(SE_StreamExecutor* executor);
|
||||
|
||||
void TpuExecutor_AllocateEvent(SE_StreamExecutor* executor, SE_Event* event,
|
||||
SE_Status* status);
|
||||
void TpuExecutor_DeallocateEvent(SE_StreamExecutor* executor, SE_Event* event,
|
||||
@ -304,6 +306,7 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateStream);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_CreateStreamDependency);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetStatus);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_GetCoreLocation);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_AllocateEvent);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_DeallocateEvent);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutor_PollForEventStatus);
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_topology.h"
|
||||
|
||||
namespace tpu {
|
||||
class TpuCore;
|
||||
@ -53,6 +54,10 @@ class TpuExecutorInterface
|
||||
}
|
||||
|
||||
virtual TpuPlatformInterface& platform() { LOG(FATAL) << "Unimplemented."; }
|
||||
|
||||
virtual TpuCoreLocationExternal GetCoreLocationExternal() const {
|
||||
LOG(FATAL) << "Unimplemented.";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
|
Loading…
Reference in New Issue
Block a user