Add TF_DefaultThreadOptions, TF_StartThread and TF_JoinThread.
PiperOrigin-RevId: 224863771
This commit is contained in:
parent
b51d81f87f
commit
4bc66cd75a
@ -159,3 +159,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) {
|
||||
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) {
|
||||
return ::tensorflow::Env::Default()->NowSeconds();
|
||||
}
|
||||
|
||||
void TF_DefaultThreadOptions(TF_ThreadOptions* options) {
|
||||
options->stack_size = 0;
|
||||
options->guard_size = 0;
|
||||
options->numa_node = -1;
|
||||
}
|
||||
|
||||
TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
|
||||
const char* thread_name, void (*work_func)(void*),
|
||||
void* param) {
|
||||
::tensorflow::ThreadOptions cc_options;
|
||||
cc_options.stack_size = options->stack_size;
|
||||
cc_options.guard_size = options->guard_size;
|
||||
cc_options.numa_node = options->numa_node;
|
||||
return reinterpret_cast<TF_Thread*>(::tensorflow::Env::Default()->StartThread(
|
||||
cc_options, thread_name, [=]() { (*work_func)(param); }));
|
||||
}
|
||||
|
||||
void TF_JoinThread(TF_Thread* thread) {
|
||||
// ::tensorflow::Thread joins on destruction
|
||||
delete reinterpret_cast<::tensorflow::Thread*>(thread);
|
||||
}
|
||||
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifndef TENSORFLOW_C_ENV_H_
|
||||
#define TENSORFLOW_C_ENV_H_
|
||||
|
||||
@ -23,6 +26,7 @@ limitations under the License.
|
||||
|
||||
struct TF_WritableFileHandle;
|
||||
struct TF_StringStream;
|
||||
struct TF_Thread;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
@ -37,6 +41,20 @@ typedef struct TF_FileStatistics {
|
||||
bool is_directory;
|
||||
} TF_FileStatistics;
|
||||
|
||||
typedef struct TF_ThreadOptions {
|
||||
// Thread stack size to use (in bytes), zero implies that the system default
|
||||
// will be used.
|
||||
size_t stack_size;
|
||||
|
||||
// Guard area size to use near thread stacks to use (in bytes), zero implies
|
||||
// that the system default will be used.
|
||||
size_t guard_size;
|
||||
|
||||
// The NUMA node to use, -1 implies that there should be no NUMA affinity for
|
||||
// this thread.
|
||||
int numa_node;
|
||||
} TF_ThreadOptions;
|
||||
|
||||
// Creates the specified directory. Typical status code are:
|
||||
// * TF_OK - successfully created the directory
|
||||
// * TF_ALREADY_EXISTS - directory already exists
|
||||
@ -150,6 +168,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void);
|
||||
// Returns the number of seconds since the Unix epoch.
|
||||
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void);
|
||||
|
||||
// Populates a TF_ThreadOptions struct with system-default values.
|
||||
TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options);
|
||||
|
||||
// Returns a new thread that is running work_func and is identified
|
||||
// (for debugging/performance-analysis) by thread_name.
|
||||
//
|
||||
// The given param (which may be null) is passed to work_func when the thread
|
||||
// starts. In this way, data may be passed from the thread back to the caller.
|
||||
//
|
||||
// Caller takes ownership of the result and must call TF_JoinThread on it
|
||||
// eventually.
|
||||
TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
|
||||
const char* thread_name,
|
||||
void (*work_func)(void*),
|
||||
void* param);
|
||||
|
||||
// Waits for the given thread to finish execution, then deletes it.
|
||||
TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -98,3 +99,29 @@ TEST(TestEnv, TestTimeFunctions) {
|
||||
ASSERT_GE(TF_NowMicros(), 946684800 * 1e6);
|
||||
ASSERT_GE(TF_NowNanos(), 946684800 * 1e9);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct SomeThreadData {
|
||||
::tensorflow::mutex mu;
|
||||
bool did_work = false;
|
||||
};
|
||||
|
||||
void SomeThreadFunc(void* data) {
|
||||
auto* real_data = static_cast<SomeThreadData*>(data);
|
||||
::tensorflow::mutex_lock l(real_data->mu);
|
||||
real_data->did_work = true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(TestEnv, TestThreads) {
|
||||
TF_ThreadOptions options;
|
||||
TF_DefaultThreadOptions(&options);
|
||||
SomeThreadData data;
|
||||
TF_Thread* thread =
|
||||
TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data);
|
||||
TF_JoinThread(thread);
|
||||
::tensorflow::mutex_lock l(data.mu);
|
||||
ASSERT_TRUE(data.did_work);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user