eager: Move "register_function" to context.py

This will allow function registration from other
modules without having to import "function.py".
(And besides, the function really does belong on the context).

PiperOrigin-RevId: 168040411
This commit is contained in:
Asim Shankar 2017-09-08 13:46:41 -07:00 committed by TensorFlower Gardener
parent 74137f994f
commit c8b9e92f07
3 changed files with 18 additions and 14 deletions

View File

@ -297,7 +297,6 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",

View File

@ -245,6 +245,23 @@ class Context(object):
# TODO(ashankar): Use TF_DeviceListType to count GPU devices.
return len(self._devices) - 1
def add_function_def(self, fdef):
"""Add a function definition to the context.
Once added, the function (identified by its name) can be executed like any
other operation.
Args:
fdef: A FunctionDef protocol buffer message.
"""
fdef_string = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextAddFunctionDef(
self._handle, # pylint: disable=protected-access
fdef_string,
len(fdef_string),
status)
def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context.

View File

@ -26,14 +26,12 @@ import threading
from autograd import core as ag_core
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
@ -438,20 +436,10 @@ def _cache_key(x):
return x
def register_function_def(fdef):
fdef_string = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextAddFunctionDef(
context.get_default_context()._handle, # pylint: disable=protected-access
fdef_string,
len(fdef_string),
status)
def _register_with_name(name, fdef):
"""Registers the function `fdef` with the name `name`."""
fdef.signature.name = name
register_function_def(fdef)
context.context().add_function_def(fdef)
# TODO(apassos): better error messages for non-hashable arguments.