Fix an issue that start_profiler_server complains 'AssertionError: Context must be initialized first.'
PiperOrigin-RevId: 253093414
This commit is contained in:
parent
bbd5f591af
commit
6c5d79930c
@ -42,10 +42,8 @@ bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
|
||||
|
||||
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
|
||||
|
||||
void TFE_ProfilerSerializeToString(TFE_Context* ctx, TFE_Profiler* profiler,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
string content;
|
||||
status->status = profiler->profiler->SerializeToString(&content);
|
||||
void* data = tensorflow::port::Malloc(content.length());
|
||||
|
@ -40,8 +40,7 @@ TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
|
||||
|
||||
// The output string is a binary string of tensorflow.tpu.Trace. User can write
|
||||
// the string to file for offline analysis by tensorboard.
|
||||
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Context* ctx,
|
||||
TFE_Profiler* profiler,
|
||||
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
|
@ -72,7 +72,11 @@ void ExecuteWithProfiling(bool async) {
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Buffer* profiler_result = TF_NewBuffer();
|
||||
TFE_ProfilerSerializeToString(ctx, profiler, profiler_result, status);
|
||||
if (async) {
|
||||
TFE_ContextAsyncWait(ctx, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
}
|
||||
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
|
||||
TFE_DeleteProfiler(profiler);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
profiler::Trace profile_proto;
|
||||
|
@ -71,9 +71,9 @@ def start():
|
||||
with _profiler_lock:
|
||||
if _profiler is not None:
|
||||
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
||||
context.ensure_initialized()
|
||||
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
|
||||
profiler_context,
|
||||
context.context()._handle) # pylint: disable=protected-access
|
||||
@ -101,9 +101,10 @@ def stop():
|
||||
if _profiler is None:
|
||||
raise ProfilerNotRunningError(
|
||||
'Cannot stop profiling. No profiler is running.')
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.async_wait()
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tensorflow.TFE_ProfilerSerializeToString(
|
||||
context.context()._handle, # pylint: disable=protected-access
|
||||
_profiler,
|
||||
buffer_)
|
||||
result = pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||
@ -162,6 +163,7 @@ def start_profiler_server(port):
|
||||
"""
|
||||
profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.ensure_initialized()
|
||||
pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
|
||||
profiler_context,
|
||||
context.context()._handle) # pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user