diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 08afc821053..96a4c2c2b97 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -956,6 +956,7 @@ cc_library( name = "tpu_ordinal_selector_op", srcs = ["tpu_ordinal_selector_op.cc"], deps = [ + ":tpu_ordinal_selector", "//tensorflow/core:framework", ], alwayslink = 1, @@ -968,3 +969,14 @@ cc_library( "//tensorflow/core:framework", ], ) + +cc_library( + name = "tpu_ordinal_selector", + hdrs = ["tpu_ordinal_selector.h"], + deps = [ + ":tpu_ordinal_selector_interface", + "//tensorflow/core:framework", + "//tensorflow/core/tpu:tpu_api", + "//tensorflow/core/tpu:tpu_ops_c_api_hdrs", + ], +) diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h new file mode 100644 index 00000000000..faf78f97dc4 --- /dev/null +++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector.h @@ -0,0 +1,58 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ +#define TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ + +#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector_interface.h" +#include "tensorflow/core/tpu/tpu_api.h" +#include "tensorflow/core/tpu/tpu_ops_c_api.h" + +namespace tensorflow { +namespace tpu { + +// A reserved ID for deferred core selection. Intentionally set at a number +// that is more than the number of cores available in a future system. +constexpr int32 kDeferredCoreSelectionReserved = -8193; + +class TPUOrdinalSelector : TPUOrdinalSelectorInterface { + public: + explicit TPUOrdinalSelector(int num_cores_per_replica = 1) { + OpsApiFn()->TfTpuOrdinalSelector_CreateFn(&ordinal_selector_, + num_cores_per_replica); + } + ~TPUOrdinalSelector() override { + OpsApiFn()->TfTpuOrdinalSelector_DestroyFn(ordinal_selector_); + } + int64 GetOrdinal(absl::optional key, int64_t* req_id) override { + int64 ordinal; + OpsApiFn()->TfTpuOrdinalSelector_GetOrdinalFn(ordinal_selector_, key, + req_id, &ordinal); + return ordinal; + } + void DequeueFromCoreSelector(int32_t device_ordinal, + int64_t req_id) override { + OpsApiFn()->TfTpuOrdinalSelector_DequeueFromCoreSelectorFn( + ordinal_selector_, device_ordinal, req_id); + } + + private: + TfTpuOrdinalSelector* ordinal_selector_; +}; + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_ORDINAL_SELECTOR_H_ diff --git a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc index 13a624b92f7..c6da029d417 100644 --- a/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_ordinal_selector_op.cc @@ -19,14 +19,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/tpu/kernels/tpu_ordinal_selector.h" namespace tensorflow { namespace { -// A reserved ID for deferred core selection. Intentionally set at a number -// that is more than the number of cores available in a future system. -constexpr int32 kDeferredCoreSelectionReserved = -8193; - // TPUOrdinalSelectorOp is a no-op for backward compatibility. The core // selection algorithm happens inside TPUPartitionedCall. class TPUOrdinalSelectorOp : public OpKernel { @@ -37,7 +34,7 @@ class TPUOrdinalSelectorOp : public OpKernel { void Compute(OpKernelContext* ctx) override { Tensor output(DT_INT32, TensorShape({})); - output.flat().setValues({kDeferredCoreSelectionReserved}); + output.flat().setValues({tpu::kDeferredCoreSelectionReserved}); ctx->set_output(0, output); ctx->SetStatus(Status::OK()); }