Consolidate/Merge the fn_args in multiple places (with double partial) and add full tests.

RELNOTES: n/a

PiperOrigin-RevId: 161113304
This commit is contained in:
Jianwei Xie 2017-07-06 12:18:21 -07:00 committed by TensorFlower Gardener
parent 70aa8daacf
commit e7fc162658
5 changed files with 206 additions and 32 deletions

View File

@ -30,6 +30,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -303,7 +304,7 @@ class TpuEstimator(estimator_lib.Estimator):
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access
input_fn_args = util.fn_args(input_fn)
config = self.config # a deep copy.
kwargs = {}
if 'params' in input_fn_args:
@ -357,7 +358,7 @@ def _verify_estimator_spec(estimator_spec):
def _call_model_fn(model_fn, features, labels, mode, config, params,
require_params=False):
"""Calls the model_fn with required parameters."""
model_fn_args = estimator_lib._fn_args(model_fn) # pylint: disable=protected-access
model_fn_args = util.fn_args(model_fn)
kwargs = {}
if 'mode' in model_fn_args:
kwargs['mode'] = mode

View File

@ -26,6 +26,7 @@ py_library(
":model_fn",
":parsing_utils",
":run_config",
":util",
"//tensorflow/python:util",
],
)
@ -212,6 +213,27 @@ py_test(
],
)
py_library(
name = "util",
srcs = [
"util.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
],
)
py_test(
name = "util_test",
srcs = ["util_test.py"],
srcs_version = "PY2AND3",
deps = [
":util",
"//tensorflow/python:client_testlib",
],
)
py_library(
name = "estimator",
srcs = [
@ -222,6 +244,7 @@ py_library(
":export",
":model_fn",
":run_config",
":util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:control_flow_ops",

View File

@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util
from tensorflow.python.estimator.export.export import build_all_signature_defs
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
from tensorflow.python.framework import ops
@ -47,7 +48,6 @@ from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@ -575,7 +575,7 @@ class Estimator(object):
ValueError: if input_fn takes invalid arguments.
"""
del mode # unused
input_fn_args = _fn_args(input_fn)
input_fn_args = util.fn_args(input_fn)
kwargs = {}
if 'params' in input_fn_args:
kwargs['params'] = self.params
@ -598,7 +598,7 @@ class Estimator(object):
Raises:
ValueError: if model_fn returns invalid objects.
"""
model_fn_args = _fn_args(self._model_fn)
model_fn_args = util.fn_args(self._model_fn)
kwargs = {}
if 'labels' in model_fn_args:
kwargs['labels'] = labels
@ -791,35 +791,9 @@ def _get_replica_device_setter(config):
return None
def _fn_args(fn):
"""Get argument names for function-like object.
Args:
fn: Function, or function-like object (e.g., result of `functools.partial`).
Returns:
`tuple` of string argument names.
Raises:
ValueError: if partial function has positionally bound arguments
"""
_, fn = tf_decorator.unwrap(fn)
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
# Handle callables.
return tuple(tf_inspect.getargspec(fn.__call__).args)
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
# Handle functools.partial and similar objects.
return tuple([
arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
if arg not in set(fn.keywords.keys())
])
# Handle function.
return tuple(tf_inspect.getargspec(fn).args)
def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
args = set(_fn_args(model_fn))
args = set(util.fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if params is not None and 'params' not in args:

View File

@ -0,0 +1,57 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility to retrieve function args.."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
def fn_args(fn):
"""Get argument names for function-like object.
Args:
fn: Function, or function-like object (e.g., result of `functools.partial`).
Returns:
`tuple` of string argument names.
Raises:
ValueError: if partial function has positionally bound arguments
"""
_, fn = tf_decorator.unwrap(fn)
# Handle callables.
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
return tuple(tf_inspect.getargspec(fn.__call__).args)
# Handle functools.partial and similar objects.
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
# Handle nested partial.
original_args = fn_args(fn.func)
if not original_args:
return tuple()
return tuple([
arg for arg in original_args[len(fn.args):]
if arg not in set((fn.keywords or {}).keys())
])
# Handle function.
return tuple(tf_inspect.getargspec(fn).args)

View File

@ -0,0 +1,119 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Estimator related util."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.estimator import util
from tensorflow.python.platform import test
class FnArgsTest(test.TestCase):
def test_simple_function(self):
def fn(a, b):
return a + b
self.assertEqual(('a', 'b'), util.fn_args(fn))
def test_callable(self):
class Foo(object):
def __call__(self, a, b):
return a + b
self.assertEqual(('self', 'a', 'b'), util.fn_args(Foo()))
def test_partial_function(self):
expected_test_arg = 123
def fn(a, test_arg):
if test_arg != expected_test_arg:
return ValueError('partial fn does not work correctly')
return a
wrapped_fn = functools.partial(fn, test_arg=123)
self.assertEqual(('a',), util.fn_args(wrapped_fn))
def test_partial_function_with_positional_args(self):
expected_test_arg = 123
def fn(test_arg, a):
if test_arg != expected_test_arg:
return ValueError('partial fn does not work correctly')
return a
wrapped_fn = functools.partial(fn, 123)
self.assertEqual(('a',), util.fn_args(wrapped_fn))
self.assertEqual(3, wrapped_fn(3))
self.assertEqual(3, wrapped_fn(a=3))
def test_double_partial(self):
expected_test_arg1 = 123
expected_test_arg2 = 456
def fn(a, test_arg1, test_arg2):
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
return ValueError('partial does not work correctly')
return a
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
def test_double_partial_with_positional_args_in_outer_layer(self):
expected_test_arg1 = 123
expected_test_arg2 = 456
def fn(test_arg1, a, test_arg2):
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
return ValueError('partial fn does not work correctly')
return a
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, 123)
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
def test_double_partial_with_positional_args_in_both_layers(self):
expected_test_arg1 = 123
expected_test_arg2 = 456
def fn(test_arg1, test_arg2, a):
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
return ValueError('partial fn does not work correctly')
return a
wrapped_fn = functools.partial(fn, 123) # binds to test_arg1
double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
if __name__ == '__main__':
test.main()