Fix an issue that start_profiler_server complains 'AssertionError: Context must be initialized first.'

PiperOrigin-RevId: 253093414
This commit is contained in:
Xiao Yu 2019-06-13 13:26:53 -07:00 committed by TensorFlower Gardener
parent bbd5f591af
commit 6c5d79930c
4 changed files with 12 additions and 9 deletions

View File

@ -42,10 +42,8 @@ bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler,
TF_Buffer* buf, TF_Status* status) {
TFE_ContextAsyncWait(ctx, status);
if (TF_GetCode(status) != TF_OK) return;
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
TF_Status* status) {
string content;
status->status = profiler->profiler->SerializeToString(&content);
void* data = tensorflow::port::Malloc(content.length());

View File

@ -40,8 +40,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
// The output string is a binary string of tensorflow.tpu.Trace. User can write
// the string to file for offline analysis by tensorboard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx,
TFE_Profiler* profiler,
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
TF_Buffer* buf,
TF_Status* status);

View File

@ -72,7 +72,11 @@ void ExecuteWithProfiling(bool async) {
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Buffer* profiler_result = TF_NewBuffer();
TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status);
if (async) {
TFE_ContextAsyncWait(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
}
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
TFE_DeleteProfiler(profiler);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
profiler::Trace profile_proto;

View File

@ -71,9 +71,9 @@ def start():
with _profiler_lock:
if _profiler is not None:
raise ProfilerAlreadyRunningError('Another profiler is running.')
context.ensure_initialized()
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
profiler_context,
context.context()._handle) # pylint: disable=protected-access
@ -101,9 +101,10 @@ def stop():
if _profiler is None:
raise ProfilerNotRunningError(
'Cannot stop profiling. No profiler is running.')
if context.default_execution_mode == context.EAGER_MODE:
context.async_wait()
with c_api_util.tf_buffer() as buffer_:
pywrap_tensorflow.TFE_ProfilerSerializeToString(
context.context()._handle, # pylint: disable=protected-access
_profiler,
buffer_)
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
@ -162,6 +163,7 @@ def start_profiler_server(port):
"""
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
profiler_context,
context.context()._handle) # pylint: disable=protected-access