Add TF_DefaultThreadOptions, TF_StartThread and TF_JoinThread.
PiperOrigin-RevId: 224863771
This commit is contained in:
parent
b51d81f87f
commit
4bc66cd75a
@ -159,3 +159,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void) {
|
|||||||
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) {
|
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void) {
|
||||||
return ::tensorflow::Env::Default()->NowSeconds();
|
return ::tensorflow::Env::Default()->NowSeconds();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TF_DefaultThreadOptions(TF_ThreadOptions* options) {
|
||||||
|
options->stack_size = 0;
|
||||||
|
options->guard_size = 0;
|
||||||
|
options->numa_node = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
|
||||||
|
const char* thread_name, void (*work_func)(void*),
|
||||||
|
void* param) {
|
||||||
|
::tensorflow::ThreadOptions cc_options;
|
||||||
|
cc_options.stack_size = options->stack_size;
|
||||||
|
cc_options.guard_size = options->guard_size;
|
||||||
|
cc_options.numa_node = options->numa_node;
|
||||||
|
return reinterpret_cast<TF_Thread*>(::tensorflow::Env::Default()->StartThread(
|
||||||
|
cc_options, thread_name, [=]() { (*work_func)(param); }));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TF_JoinThread(TF_Thread* thread) {
|
||||||
|
// ::tensorflow::Thread joins on destruction
|
||||||
|
delete reinterpret_cast<::tensorflow::Thread*>(thread);
|
||||||
|
}
|
||||||
|
@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
#ifndef TENSORFLOW_C_ENV_H_
|
#ifndef TENSORFLOW_C_ENV_H_
|
||||||
#define TENSORFLOW_C_ENV_H_
|
#define TENSORFLOW_C_ENV_H_
|
||||||
|
|
||||||
@ -23,6 +26,7 @@ limitations under the License.
|
|||||||
|
|
||||||
struct TF_WritableFileHandle;
|
struct TF_WritableFileHandle;
|
||||||
struct TF_StringStream;
|
struct TF_StringStream;
|
||||||
|
struct TF_Thread;
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -37,6 +41,20 @@ typedef struct TF_FileStatistics {
|
|||||||
bool is_directory;
|
bool is_directory;
|
||||||
} TF_FileStatistics;
|
} TF_FileStatistics;
|
||||||
|
|
||||||
|
typedef struct TF_ThreadOptions {
|
||||||
|
// Thread stack size to use (in bytes), zero implies that the system default
|
||||||
|
// will be used.
|
||||||
|
size_t stack_size;
|
||||||
|
|
||||||
|
// Guard area size to use near thread stacks to use (in bytes), zero implies
|
||||||
|
// that the system default will be used.
|
||||||
|
size_t guard_size;
|
||||||
|
|
||||||
|
// The NUMA node to use, -1 implies that there should be no NUMA affinity for
|
||||||
|
// this thread.
|
||||||
|
int numa_node;
|
||||||
|
} TF_ThreadOptions;
|
||||||
|
|
||||||
// Creates the specified directory. Typical status code are:
|
// Creates the specified directory. Typical status code are:
|
||||||
// * TF_OK - successfully created the directory
|
// * TF_OK - successfully created the directory
|
||||||
// * TF_ALREADY_EXISTS - directory already exists
|
// * TF_ALREADY_EXISTS - directory already exists
|
||||||
@ -150,6 +168,25 @@ TF_CAPI_EXPORT extern uint64_t TF_NowMicros(void);
|
|||||||
// Returns the number of seconds since the Unix epoch.
|
// Returns the number of seconds since the Unix epoch.
|
||||||
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void);
|
TF_CAPI_EXPORT extern uint64_t TF_NowSeconds(void);
|
||||||
|
|
||||||
|
// Populates a TF_ThreadOptions struct with system-default values.
|
||||||
|
TF_CAPI_EXPORT extern void TF_DefaultThreadOptions(TF_ThreadOptions* options);
|
||||||
|
|
||||||
|
// Returns a new thread that is running work_func and is identified
|
||||||
|
// (for debugging/performance-analysis) by thread_name.
|
||||||
|
//
|
||||||
|
// The given param (which may be null) is passed to work_func when the thread
|
||||||
|
// starts. In this way, data may be passed from the thread back to the caller.
|
||||||
|
//
|
||||||
|
// Caller takes ownership of the result and must call TF_JoinThread on it
|
||||||
|
// eventually.
|
||||||
|
TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options,
|
||||||
|
const char* thread_name,
|
||||||
|
void (*work_func)(void*),
|
||||||
|
void* param);
|
||||||
|
|
||||||
|
// Waits for the given thread to finish execution, then deletes it.
|
||||||
|
TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -98,3 +99,29 @@ TEST(TestEnv, TestTimeFunctions) {
|
|||||||
ASSERT_GE(TF_NowMicros(), 946684800 * 1e6);
|
ASSERT_GE(TF_NowMicros(), 946684800 * 1e6);
|
||||||
ASSERT_GE(TF_NowNanos(), 946684800 * 1e9);
|
ASSERT_GE(TF_NowNanos(), 946684800 * 1e9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct SomeThreadData {
|
||||||
|
::tensorflow::mutex mu;
|
||||||
|
bool did_work = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
void SomeThreadFunc(void* data) {
|
||||||
|
auto* real_data = static_cast<SomeThreadData*>(data);
|
||||||
|
::tensorflow::mutex_lock l(real_data->mu);
|
||||||
|
real_data->did_work = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(TestEnv, TestThreads) {
|
||||||
|
TF_ThreadOptions options;
|
||||||
|
TF_DefaultThreadOptions(&options);
|
||||||
|
SomeThreadData data;
|
||||||
|
TF_Thread* thread =
|
||||||
|
TF_StartThread(&options, "SomeThreadName", &SomeThreadFunc, &data);
|
||||||
|
TF_JoinThread(thread);
|
||||||
|
::tensorflow::mutex_lock l(data.mu);
|
||||||
|
ASSERT_TRUE(data.did_work);
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user