Implement OSS TPU ordinal selector.
PiperOrigin-RevId: 360295595 Change-Id: I70168d92e735e5f42c46b0b7742195649b094433
This commit is contained in:
parent
1038045628
commit
23466d6a7b
@ -956,6 +956,7 @@ cc_library(
|
|||||||
name = "tpu_ordinal_selector_op",
|
name = "tpu_ordinal_selector_op",
|
||||||
srcs = ["tpu_ordinal_selector_op.cc"],
|
srcs = ["tpu_ordinal_selector_op.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":tpu_ordinal_selector",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
@ -968,3 +969,14 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tpu_ordinal_selector",
|
||||||
|
hdrs = ["tpu_ordinal_selector.h"],
|
||||||
|
deps = [
|
||||||
|
":tpu_ordinal_selector_interface",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core/tpu:tpu_api",
|
||||||
|
"//tensorflow/core/tpu:tpu_ops_c_api_hdrs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
58
tensorflow/core/tpu/kernels/tpu_ordinal_selector.h
Normal file
58
tensorflow/core/tpu/kernels/tpu_ordinal_selector.h
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
|
||||||
|
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
|
||||||
|
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
|
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
|
||||||
|
// A reserved ID for deferred core selection. Intentionally set at a number
|
||||||
|
// that is more than the number of cores available in a future system.
|
||||||
|
constexpr int32 kDeferredCoreSelectionReserved = -8193;
|
||||||
|
|
||||||
|
class TPUOrdinalSelector : TPUOrdinalSelectorInterface {
|
||||||
|
public:
|
||||||
|
explicit TPUOrdinalSelector(int num_cores_per_replica = 1) {
|
||||||
|
OpsApiFn()->TfTpuOrdinalSelector_CreateFn(&ordinal_selector_,
|
||||||
|
num_cores_per_replica);
|
||||||
|
}
|
||||||
|
~TPUOrdinalSelector() override {
|
||||||
|
OpsApiFn()->TfTpuOrdinalSelector_DestroyFn(ordinal_selector_);
|
||||||
|
}
|
||||||
|
int64 GetOrdinal(absl::optional<uint64> key, int64_t* req_id) override {
|
||||||
|
int64 ordinal;
|
||||||
|
OpsApiFn()->TfTpuOrdinalSelector_GetOrdinalFn(ordinal_selector_, key,
|
||||||
|
req_id, &ordinal);
|
||||||
|
return ordinal;
|
||||||
|
}
|
||||||
|
void DequeueFromCoreSelector(int32_t device_ordinal,
|
||||||
|
int64_t req_id) override {
|
||||||
|
OpsApiFn()->TfTpuOrdinalSelector_DequeueFromCoreSelectorFn(
|
||||||
|
ordinal_selector_, device_ordinal, req_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TfTpuOrdinalSelector* ordinal_selector_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_
|
@ -19,14 +19,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.h"
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// A reserved ID for deferred core selection. Intentionally set at a number
|
|
||||||
// that is more than the number of cores available in a future system.
|
|
||||||
constexpr int32 kDeferredCoreSelectionReserved = -8193;
|
|
||||||
|
|
||||||
// TPUOrdinalSelectorOp is a no-op for backward compatibility. The core
|
// TPUOrdinalSelectorOp is a no-op for backward compatibility. The core
|
||||||
// selection algorithm happens inside TPUPartitionedCall.
|
// selection algorithm happens inside TPUPartitionedCall.
|
||||||
class TPUOrdinalSelectorOp : public OpKernel {
|
class TPUOrdinalSelectorOp : public OpKernel {
|
||||||
@ -37,7 +34,7 @@ class TPUOrdinalSelectorOp : public OpKernel {
|
|||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
void Compute(OpKernelContext* ctx) override {
|
||||||
Tensor output(DT_INT32, TensorShape({}));
|
Tensor output(DT_INT32, TensorShape({}));
|
||||||
output.flat<int>().setValues({kDeferredCoreSelectionReserved});
|
output.flat<int>().setValues({tpu::kDeferredCoreSelectionReserved});
|
||||||
ctx->set_output(0, output);
|
ctx->set_output(0, output);
|
||||||
ctx->SetStatus(Status::OK());
|
ctx->SetStatus(Status::OK());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user