eager: Move "register_function" to context.py

This will allow function registration from other
modules without having to import "function.py".
(And besides, the function really does belong on the context).

PiperOrigin-RevId: 168040411
This commit is contained in:
Asim Shankar 2017-09-08 13:46:41 -07:00 committed by TensorFlower Gardener
parent 74137f994f
commit c8b9e92f07
3 changed files with 18 additions and 14 deletions

View File

@ -297,7 +297,6 @@ py_library(
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:graph_to_function_def", "//tensorflow/python:graph_to_function_def",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:core", "//tensorflow/python/eager:core",

View File

@ -245,6 +245,23 @@ class Context(object):
# TODO(ashankar): Use TF_DeviceListType to count GPU devices. # TODO(ashankar): Use TF_DeviceListType to count GPU devices.
return len(self._devices) - 1 return len(self._devices) - 1
def add_function_def(self, fdef):
"""Add a function definition to the context.
Once added, the function (identified by its name) can be executed like any
other operation.
Args:
fdef: A FunctionDef protocol buffer message.
"""
fdef_string = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextAddFunctionDef(
self._handle, # pylint: disable=protected-access
fdef_string,
len(fdef_string),
status)
def add_post_execution_callback(self, callback): def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context. """Add a post-execution callback to the context.

View File

@ -26,14 +26,12 @@ import threading
from autograd import core as ag_core from autograd import core as ag_core
import numpy as np import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import execute from tensorflow.python.eager import execute
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor from tensorflow.python.eager import tensor
from tensorflow.python.eager.graph_only_ops import graph_placeholder from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
@ -438,20 +436,10 @@ def _cache_key(x):
return x return x
def register_function_def(fdef):
fdef_string = fdef.SerializeToString()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.TFE_ContextAddFunctionDef(
context.get_default_context()._handle, # pylint: disable=protected-access
fdef_string,
len(fdef_string),
status)
def _register_with_name(name, fdef): def _register_with_name(name, fdef):
"""Registers the function `fdef` with the name `name`.""" """Registers the function `fdef` with the name `name`."""
fdef.signature.name = name fdef.signature.name = name
register_function_def(fdef) context.context().add_function_def(fdef)
# TODO(apassos): better error messages for non-hashable arguments. # TODO(apassos): better error messages for non-hashable arguments.