From f4c54a716543a57fbbd0e163312136dc47414b13 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Tue, 12 May 2020 01:45:25 -0700 Subject: [PATCH] Add TPU Configuration C API PiperOrigin-RevId: 311082982 Change-Id: I18031e1c84d28b37cbf1cdd68372e351d2da476a --- tensorflow/core/tpu/BUILD | 8 ++++ tensorflow/core/tpu/tpu_config_c_api.h | 54 ++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 tensorflow/core/tpu/tpu_config_c_api.h diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 43b2d93b917..4ea5fc39929 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -51,3 +51,11 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "tpu_config_c_api", + hdrs = ["tpu_config_c_api.h"], + deps = [ + "//tensorflow/c:tf_status", + ], +) diff --git a/tensorflow/core/tpu/tpu_config_c_api.h b/tensorflow/core/tpu/tpu_config_c_api.h new file mode 100644 index 00000000000..334a6a19325 --- /dev/null +++ b/tensorflow/core/tpu/tpu_config_c_api.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_ +#define TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_ + +#include + +#include "tensorflow/c/tf_status.h" + +typedef struct TpuSerializedProto TpuSerializedProto; + +extern "C" { + +bool TPUHostInitialized(); + +// TODO(frankchn): Modify API to take in raw values instead of Tensors. +void ConfigureDistributedTpuOp_DoWork(size_t input_size, + TpuSerializedProto** inputs, + TpuSerializedProto* output, + TF_Status* status); + +void WaitForDistributedTpuOp_DoWork(size_t input_size, + TpuSerializedProto** inputs, + TpuSerializedProto* output, + TF_Status* status); + +void ShutdownDistributedTpuOp_DoWork(TF_Status* status); + +void InitializeHostForDistributedTpuOp_DoWork( + size_t input_size, TpuSerializedProto** inputs, + bool enable_whole_mesh_compilations, TpuSerializedProto* output, + TF_Status* status); + +void SetGlobalTPUArrayOp_DoWork(size_t input_size, TpuSerializedProto** inputs, + TF_Status* status); + +void DisconnectDistributedTpuChipsOp_DoWork(TpuSerializedProto* output, + TF_Status* status); +} + +#endif // TENSORFLOW_CORE_TPU_TPU_CONFIG_C_API_H_