Add a functiondef getter to the context

PiperOrigin-RevId: 296002833
Change-Id: I238a2984a9320c084b7157e6eeb30b30aa132036
This commit is contained in:
Akshay Modi 2020-02-19 10:40:38 -08:00 committed by TensorFlower Gardener
parent 61b85f68db
commit fa5cdeae7e
7 changed files with 85 additions and 0 deletions

View File

@ -569,3 +569,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
h->handle->EnableImplicitMirroring();
status->status = tensorflow::Status::OK();
}
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
TF_Buffer* buf, TF_Status* status) {
auto* function_def = ctx->context->FindFunctionDef(function_name);
if (function_def == nullptr) {
status->status = tensorflow::errors::NotFound(
"Unable to find FunctionDef with name: ", function_name);
return;
}
string str = function_def->SerializeAsString();
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
status->status = tensorflow::Status::OK();
}

View File

@ -475,6 +475,11 @@ typedef struct TFE_CustomDevice {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
const char* function_name,
TF_Buffer* buf,
TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -622,6 +622,10 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
return Status::OK();
}
const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
return func_lib_def_.Find(function_name);
}
Status EagerContext::RemoveFunction(const string& func) {
bool is_last_ref = false;
{

View File

@ -232,6 +232,8 @@ class EagerContext : public core::RefCounted {
const FunctionDefLibrary& library,
const bool add_to_local_only = false);
const FunctionDef* GetFunctionDef(const string& function_name);
Status RemoveFunction(const string& func);
// Clear remote executors on all worker targets in `remote_contexts_`.

View File

@ -28,6 +28,7 @@ from absl import logging
import numpy as np
import six
from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tfe
@ -1054,6 +1055,26 @@ class Context(object):
pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
len(fdef_string))
def get_function_def(self, name):
"""Get a function definition from the context.
Args:
name: function signature name.
Returns:
The requested FunctionDef.
Raises:
tf.errors.NotFoundError: if name is not the name of a registered function.
"""
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
function_def = function_pb2.FunctionDef()
function_def.ParseFromString(proto_data)
return function_def
def remove_function(self, name):
"""Remove a function from the context.
@ -2124,6 +2145,10 @@ def remove_function(name):
context().remove_function(name)
def get_function_def(name):
return context().get_function_def(name)
# Not every user creates a Context via context.context()
# (for example, enable_eager_execution in python/framework/ops.py),
# but they do all import this file. Note that IS_IN_GRAPH_MODE and

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@ -86,6 +87,27 @@ class ContextTest(test.TestCase):
graph, = graphs
self.assertIn('CPU:0', graph.node[0].device)
def testGetFunctionDef(self):
@def_function.function
def f():
return constant_op.constant(1.)
concrete = f.get_concrete_function()
function_def = context.get_function_def(concrete.name)
self.assertIsNot(function_def, None)
found_const_node = False
for node_def in function_def.node_def:
if node_def.op == 'Const':
found_const_node = True
break
self.assertTrue(found_const_node)
with self.assertRaises(errors.NotFoundError):
_ = context.get_function_def('this_should_not_be_found')
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -382,6 +382,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextGetFunctionDef",
[](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
function_name, &buf, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());