Remove deprecated run_as argument for convert.

PiperOrigin-RevId: 260190100
This commit is contained in:
Dan Moldovan 2019-07-26 12:16:43 -07:00 committed by TensorFlower Gardener
parent 7da40d6fb7
commit 3a93e33d54
6 changed files with 7 additions and 109 deletions

View File

@ -40,7 +40,6 @@ from tensorflow.python.autograph.impl.api import AutoGraphError
from tensorflow.python.autograph.impl.api import convert
from tensorflow.python.autograph.impl.api import converted_call
from tensorflow.python.autograph.impl.api import do_not_convert
from tensorflow.python.autograph.impl.api import RunMode
from tensorflow.python.autograph.impl.api import StackTraceMapper
from tensorflow.python.autograph.impl.api import to_code
from tensorflow.python.autograph.impl.api import to_graph
@ -56,7 +55,6 @@ _allowed_symbols = [
'AutoGraphError',
'ConversionOptions',
'Feature',
'RunMode',
'StackTraceMapper',
'convert',
'converted_call',

View File

@ -1,33 +0,0 @@
# Specifying return data type for `py_func` calls
The `py_func` op requires specifying a
[data type](https://www.tensorflow.org/guide/tensors#data_types).
When wrapping a function with `py_func`, for instance using
`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
options to specify the returned data type:
* explicitly, with a specified `tf.DType` value
* by matching the data type of an input argument, which is then assumed to be
a `Tensor`
Examples:
Specify an explicit data type:
```
def foo(a):
return a + 1
autograph.util.wrap_py_func(f, return_dtypes=[tf.float32])
```
Match the data type of the first argument:
```
def foo(a):
return a + 1
autograph.util.wrap_py_func(
f, return_dtypes=[autograph.utils.py_func.MatchDType(0)])
```

View File

@ -28,7 +28,6 @@ import re
import sys
import textwrap
import traceback
from enum import Enum
# pylint:disable=g-bad-import-order
import six
@ -42,7 +41,6 @@ from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@ -238,20 +236,6 @@ def convert(recursive=False, optional_features=None, force_conversion=True):
return decorator
class RunMode(Enum):
"""Specifies the way a converted function or method should be executed in TF.
Attributes:
* GRAPH: Call this function directly, as-is. This is suitable for functions
that were already designed for TF graphs and contain ops.
* PY_FUNC: Wrap this function into a py_func op. This is suitable for code
that will only run correctly in Python, for example code that renders to
the display, reads keyboard input, etc.
"""
GRAPH = 1
PY_FUNC = 2
def call_with_unspecified_conversion_status(func):
"""Decorator that resets the conversion context to the unspecified status."""
def wrapper(*args, **kwargs):
@ -272,18 +256,11 @@ def do_not_convert_internal(f):
@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
def do_not_convert(func=None):
"""Decorator that suppresses the conversion of a function.
See also: docs/pyfunc_dtypes.md
Args:
func: function to decorate.
run_as: RunMode, specifies how to use the function in TensorFlow.
return_dtypes: Optional[Iterable[ Union[tf.DType,
utils.py_func.MatchDType]]], the return data types of the converted
function, if run_as is RunMode.PY_FUNC. Ignored otherwise. May be set to
None if the function has no return values.
Returns:
If `func` is not None, returns a `Callable` which is equivalent to
@ -293,29 +270,12 @@ def do_not_convert(func=None, run_as=RunMode.GRAPH, return_dtypes=None):
above case.
"""
if func is None:
return functools.partial(
do_not_convert,
run_as=run_as,
return_dtypes=return_dtypes)
return do_not_convert
def graph_wrapper(*args, **kwargs):
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
return func(*args, **kwargs)
def py_func_wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
# TODO(mdan): Add support for kwargs.
return py_func.wrap_py_func(
func, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
if run_as == RunMode.GRAPH:
wrapper = graph_wrapper
elif run_as == RunMode.PY_FUNC:
wrapper = py_func_wrapper
else:
raise ValueError('unknown value for run_as: %s' % run_as)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)

View File

@ -35,7 +35,6 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
@ -108,11 +107,11 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
def test_convert_then_do_not_convert_graph(self):
def test_convert_then_do_not_convert(self):
class TestClass(object):
@api.do_not_convert(run_as=api.RunMode.GRAPH)
@api.do_not_convert
def called_member(self, a):
return tf.negative(a)
@ -128,32 +127,6 @@ class ApiTest(test.TestCase):
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_convert_then_do_not_convert_py_func(self):
class TestClass(object):
@api.do_not_convert(
run_as=api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1))
def called_member(self, a):
return np.negative(a)
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
y = self.called_member(a)
# set_shape works around while_loop's limitations.
# TODO(mdan): Allow specifying shapes (or ShapeLike) instead.
y.set_shape(a.shape)
x //= y
return x
tc = TestClass()
x = tc.test_method(
constant_op.constant((2, 4)), constant_op.constant(1),
constant_op.constant(-2))
self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_decorator_calls_decorated(self):

View File

@ -6,7 +6,7 @@ tf_module {
}
member_method {
name: "do_not_convert"
argspec: "args=[\'func\', \'run_as\', \'return_dtypes\'], varargs=None, keywords=None, defaults=[\'None\', \'RunMode.GRAPH\', \'None\'], "
argspec: "args=[\'func\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_loop_options"

View File

@ -6,7 +6,7 @@ tf_module {
}
member_method {
name: "do_not_convert"
argspec: "args=[\'func\', \'run_as\', \'return_dtypes\'], varargs=None, keywords=None, defaults=[\'None\', \'RunMode.GRAPH\', \'None\'], "
argspec: "args=[\'func\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "set_loop_options"