eager: Move "register_function" to context.py
This will allow function registration from other modules without having to import "function.py". (And besides, the function really does belong on the context). PiperOrigin-RevId: 168040411
This commit is contained in:
parent
74137f994f
commit
c8b9e92f07
@ -297,7 +297,6 @@ py_library(
|
|||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:graph_to_function_def",
|
"//tensorflow/python:graph_to_function_def",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
"//tensorflow/python/eager:core",
|
"//tensorflow/python/eager:core",
|
||||||
|
@ -245,6 +245,23 @@ class Context(object):
|
|||||||
# TODO(ashankar): Use TF_DeviceListType to count GPU devices.
|
# TODO(ashankar): Use TF_DeviceListType to count GPU devices.
|
||||||
return len(self._devices) - 1
|
return len(self._devices) - 1
|
||||||
|
|
||||||
|
def add_function_def(self, fdef):
|
||||||
|
"""Add a function definition to the context.
|
||||||
|
|
||||||
|
Once added, the function (identified by its name) can be executed like any
|
||||||
|
other operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fdef: A FunctionDef protocol buffer message.
|
||||||
|
"""
|
||||||
|
fdef_string = fdef.SerializeToString()
|
||||||
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
|
pywrap_tensorflow.TFE_ContextAddFunctionDef(
|
||||||
|
self._handle, # pylint: disable=protected-access
|
||||||
|
fdef_string,
|
||||||
|
len(fdef_string),
|
||||||
|
status)
|
||||||
|
|
||||||
def add_post_execution_callback(self, callback):
|
def add_post_execution_callback(self, callback):
|
||||||
"""Add a post-execution callback to the context.
|
"""Add a post-execution callback to the context.
|
||||||
|
|
||||||
|
@ -26,14 +26,12 @@ import threading
|
|||||||
from autograd import core as ag_core
|
from autograd import core as ag_core
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import execute
|
from tensorflow.python.eager import execute
|
||||||
from tensorflow.python.eager import tape
|
from tensorflow.python.eager import tape
|
||||||
from tensorflow.python.eager import tensor
|
from tensorflow.python.eager import tensor
|
||||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
from tensorflow.python.framework import graph_to_function_def
|
from tensorflow.python.framework import graph_to_function_def
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
@ -438,20 +436,10 @@ def _cache_key(x):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def register_function_def(fdef):
|
|
||||||
fdef_string = fdef.SerializeToString()
|
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
|
||||||
pywrap_tensorflow.TFE_ContextAddFunctionDef(
|
|
||||||
context.get_default_context()._handle, # pylint: disable=protected-access
|
|
||||||
fdef_string,
|
|
||||||
len(fdef_string),
|
|
||||||
status)
|
|
||||||
|
|
||||||
|
|
||||||
def _register_with_name(name, fdef):
|
def _register_with_name(name, fdef):
|
||||||
"""Registers the function `fdef` with the name `name`."""
|
"""Registers the function `fdef` with the name `name`."""
|
||||||
fdef.signature.name = name
|
fdef.signature.name = name
|
||||||
register_function_def(fdef)
|
context.context().add_function_def(fdef)
|
||||||
|
|
||||||
|
|
||||||
# TODO(apassos): better error messages for non-hashable arguments.
|
# TODO(apassos): better error messages for non-hashable arguments.
|
||||||
|
Loading…
Reference in New Issue
Block a user