Add a functiondef getter to the context
PiperOrigin-RevId: 296002833 Change-Id: I238a2984a9320c084b7157e6eeb30b30aa132036
This commit is contained in:
parent
61b85f68db
commit
fa5cdeae7e
@ -569,3 +569,22 @@ void TFE_TensorHandleEnableImplicitMirroring(TFE_TensorHandle* h,
|
||||
h->handle->EnableImplicitMirroring();
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
|
||||
TF_Buffer* buf, TF_Status* status) {
|
||||
auto* function_def = ctx->context->FindFunctionDef(function_name);
|
||||
if (function_def == nullptr) {
|
||||
status->status = tensorflow::errors::NotFound(
|
||||
"Unable to find FunctionDef with name: ", function_name);
|
||||
return;
|
||||
}
|
||||
string str = function_def->SerializeAsString();
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = str.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
status->status = tensorflow::Status::OK();
|
||||
}
|
||||
|
@ -475,6 +475,11 @@ typedef struct TFE_CustomDevice {
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
|
||||
const char* function_name,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -622,6 +622,10 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const FunctionDef* EagerContext::GetFunctionDef(const string& function_name) {
|
||||
return func_lib_def_.Find(function_name);
|
||||
}
|
||||
|
||||
Status EagerContext::RemoveFunction(const string& func) {
|
||||
bool is_last_ref = false;
|
||||
{
|
||||
|
@ -232,6 +232,8 @@ class EagerContext : public core::RefCounted {
|
||||
const FunctionDefLibrary& library,
|
||||
const bool add_to_local_only = false);
|
||||
|
||||
const FunctionDef* GetFunctionDef(const string& function_name);
|
||||
|
||||
Status RemoveFunction(const string& func);
|
||||
|
||||
// Clear remote executors on all worker targets in `remote_contexts_`.
|
||||
|
@ -28,6 +28,7 @@ from absl import logging
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.core.framework import function_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python import pywrap_tfe
|
||||
@ -1054,6 +1055,26 @@ class Context(object):
|
||||
pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
|
||||
len(fdef_string))
|
||||
|
||||
def get_function_def(self, name):
|
||||
"""Get a function definition from the context.
|
||||
|
||||
Args:
|
||||
name: function signature name.
|
||||
|
||||
Returns:
|
||||
The requested FunctionDef.
|
||||
|
||||
Raises:
|
||||
tf.errors.NotFoundError: if name is not the name of a registered function.
|
||||
"""
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
|
||||
proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
function_def = function_pb2.FunctionDef()
|
||||
function_def.ParseFromString(proto_data)
|
||||
|
||||
return function_def
|
||||
|
||||
def remove_function(self, name):
|
||||
"""Remove a function from the context.
|
||||
|
||||
@ -2124,6 +2145,10 @@ def remove_function(name):
|
||||
context().remove_function(name)
|
||||
|
||||
|
||||
def get_function_def(name):
|
||||
return context().get_function_def(name)
|
||||
|
||||
|
||||
# Not every user creates a Context via context.context()
|
||||
# (for example, enable_eager_execution in python/framework/ops.py),
|
||||
# but they do all import this file. Note that IS_IN_GRAPH_MODE and
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -86,6 +87,27 @@ class ContextTest(test.TestCase):
|
||||
graph, = graphs
|
||||
self.assertIn('CPU:0', graph.node[0].device)
|
||||
|
||||
def testGetFunctionDef(self):
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
return constant_op.constant(1.)
|
||||
|
||||
concrete = f.get_concrete_function()
|
||||
function_def = context.get_function_def(concrete.name)
|
||||
|
||||
self.assertIsNot(function_def, None)
|
||||
|
||||
found_const_node = False
|
||||
for node_def in function_def.node_def:
|
||||
if node_def.op == 'Const':
|
||||
found_const_node = True
|
||||
break
|
||||
self.assertTrue(found_const_node)
|
||||
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
_ = context.get_function_def('this_should_not_be_found')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -382,6 +382,14 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
});
|
||||
m.def("TFE_ContextGetFunctionDef",
|
||||
[](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
|
||||
function_name, &buf, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
});
|
||||
m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
|
Loading…
Reference in New Issue
Block a user