eager: Move "register_function" to context.py
This will allow function registration from other modules without having to import "function.py". (And besides, the function really does belong on the context). PiperOrigin-RevId: 168040411
This commit is contained in:
parent
74137f994f
commit
c8b9e92f07
@ -297,7 +297,6 @@ py_library(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:graph_to_function_def",
|
||||
"//tensorflow/python:pywrap_tensorflow",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:core",
|
||||
|
@ -245,6 +245,23 @@ class Context(object):
|
||||
# TODO(ashankar): Use TF_DeviceListType to count GPU devices.
|
||||
return len(self._devices) - 1
|
||||
|
||||
def add_function_def(self, fdef):
|
||||
"""Add a function definition to the context.
|
||||
|
||||
Once added, the function (identified by its name) can be executed like any
|
||||
other operation.
|
||||
|
||||
Args:
|
||||
fdef: A FunctionDef protocol buffer message.
|
||||
"""
|
||||
fdef_string = fdef.SerializeToString()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.TFE_ContextAddFunctionDef(
|
||||
self._handle, # pylint: disable=protected-access
|
||||
fdef_string,
|
||||
len(fdef_string),
|
||||
status)
|
||||
|
||||
def add_post_execution_callback(self, callback):
|
||||
"""Add a post-execution callback to the context.
|
||||
|
||||
|
@ -26,14 +26,12 @@ import threading
|
||||
from autograd import core as ag_core
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import execute
|
||||
from tensorflow.python.eager import tape
|
||||
from tensorflow.python.eager import tensor
|
||||
from tensorflow.python.eager.graph_only_ops import graph_placeholder
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import graph_to_function_def
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
@ -438,20 +436,10 @@ def _cache_key(x):
|
||||
return x
|
||||
|
||||
|
||||
def register_function_def(fdef):
|
||||
fdef_string = fdef.SerializeToString()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.TFE_ContextAddFunctionDef(
|
||||
context.get_default_context()._handle, # pylint: disable=protected-access
|
||||
fdef_string,
|
||||
len(fdef_string),
|
||||
status)
|
||||
|
||||
|
||||
def _register_with_name(name, fdef):
|
||||
"""Registers the function `fdef` with the name `name`."""
|
||||
fdef.signature.name = name
|
||||
register_function_def(fdef)
|
||||
context.context().add_function_def(fdef)
|
||||
|
||||
|
||||
# TODO(apassos): better error messages for non-hashable arguments.
|
||||
|
Loading…
Reference in New Issue
Block a user