Consolidate/Merge the fn_args in multiple places (with double partial) and add full tests.
RELNOTES: n/a PiperOrigin-RevId: 161113304
This commit is contained in:
parent
70aa8daacf
commit
e7fc162658
@ -30,6 +30,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop
|
||||
|
||||
from tensorflow.python.estimator import estimator as estimator_lib
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator import util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -303,7 +304,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
||||
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
||||
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
|
||||
|
||||
input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access
|
||||
input_fn_args = util.fn_args(input_fn)
|
||||
config = self.config # a deep copy.
|
||||
kwargs = {}
|
||||
if 'params' in input_fn_args:
|
||||
@ -357,7 +358,7 @@ def _verify_estimator_spec(estimator_spec):
|
||||
def _call_model_fn(model_fn, features, labels, mode, config, params,
|
||||
require_params=False):
|
||||
"""Calls the model_fn with required parameters."""
|
||||
model_fn_args = estimator_lib._fn_args(model_fn) # pylint: disable=protected-access
|
||||
model_fn_args = util.fn_args(model_fn)
|
||||
kwargs = {}
|
||||
if 'mode' in model_fn_args:
|
||||
kwargs['mode'] = mode
|
||||
|
@ -26,6 +26,7 @@ py_library(
|
||||
":model_fn",
|
||||
":parsing_utils",
|
||||
":run_config",
|
||||
":util",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -212,6 +213,27 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "util",
|
||||
srcs = [
|
||||
"util.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "util_test",
|
||||
srcs = ["util_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":util",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator",
|
||||
srcs = [
|
||||
@ -222,6 +244,7 @@ py_library(
|
||||
":export",
|
||||
":model_fn",
|
||||
":run_config",
|
||||
":util",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session as tf_session
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator import run_config
|
||||
from tensorflow.python.estimator import util
|
||||
from tensorflow.python.estimator.export.export import build_all_signature_defs
|
||||
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
|
||||
from tensorflow.python.framework import ops
|
||||
@ -47,7 +48,6 @@ from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.training import training
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
@ -575,7 +575,7 @@ class Estimator(object):
|
||||
ValueError: if input_fn takes invalid arguments.
|
||||
"""
|
||||
del mode # unused
|
||||
input_fn_args = _fn_args(input_fn)
|
||||
input_fn_args = util.fn_args(input_fn)
|
||||
kwargs = {}
|
||||
if 'params' in input_fn_args:
|
||||
kwargs['params'] = self.params
|
||||
@ -598,7 +598,7 @@ class Estimator(object):
|
||||
Raises:
|
||||
ValueError: if model_fn returns invalid objects.
|
||||
"""
|
||||
model_fn_args = _fn_args(self._model_fn)
|
||||
model_fn_args = util.fn_args(self._model_fn)
|
||||
kwargs = {}
|
||||
if 'labels' in model_fn_args:
|
||||
kwargs['labels'] = labels
|
||||
@ -791,35 +791,9 @@ def _get_replica_device_setter(config):
|
||||
return None
|
||||
|
||||
|
||||
def _fn_args(fn):
|
||||
"""Get argument names for function-like object.
|
||||
|
||||
Args:
|
||||
fn: Function, or function-like object (e.g., result of `functools.partial`).
|
||||
|
||||
Returns:
|
||||
`tuple` of string argument names.
|
||||
|
||||
Raises:
|
||||
ValueError: if partial function has positionally bound arguments
|
||||
"""
|
||||
_, fn = tf_decorator.unwrap(fn)
|
||||
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
|
||||
# Handle callables.
|
||||
return tuple(tf_inspect.getargspec(fn.__call__).args)
|
||||
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
|
||||
# Handle functools.partial and similar objects.
|
||||
return tuple([
|
||||
arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
|
||||
if arg not in set(fn.keywords.keys())
|
||||
])
|
||||
# Handle function.
|
||||
return tuple(tf_inspect.getargspec(fn).args)
|
||||
|
||||
|
||||
def _verify_model_fn_args(model_fn, params):
|
||||
"""Verifies model fn arguments."""
|
||||
args = set(_fn_args(model_fn))
|
||||
args = set(util.fn_args(model_fn))
|
||||
if 'features' not in args:
|
||||
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
|
||||
if params is not None and 'params' not in args:
|
||||
|
57
tensorflow/python/estimator/util.py
Normal file
57
tensorflow/python/estimator/util.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Utility to retrieve function args.."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
def fn_args(fn):
|
||||
"""Get argument names for function-like object.
|
||||
|
||||
Args:
|
||||
fn: Function, or function-like object (e.g., result of `functools.partial`).
|
||||
|
||||
Returns:
|
||||
`tuple` of string argument names.
|
||||
|
||||
Raises:
|
||||
ValueError: if partial function has positionally bound arguments
|
||||
"""
|
||||
_, fn = tf_decorator.unwrap(fn)
|
||||
|
||||
# Handle callables.
|
||||
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
|
||||
return tuple(tf_inspect.getargspec(fn.__call__).args)
|
||||
|
||||
# Handle functools.partial and similar objects.
|
||||
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
|
||||
# Handle nested partial.
|
||||
original_args = fn_args(fn.func)
|
||||
if not original_args:
|
||||
return tuple()
|
||||
|
||||
return tuple([
|
||||
arg for arg in original_args[len(fn.args):]
|
||||
if arg not in set((fn.keywords or {}).keys())
|
||||
])
|
||||
|
||||
# Handle function.
|
||||
return tuple(tf_inspect.getargspec(fn).args)
|
119
tensorflow/python/estimator/util_test.py
Normal file
119
tensorflow/python/estimator/util_test.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Estimator related util."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.python.estimator import util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FnArgsTest(test.TestCase):
|
||||
|
||||
def test_simple_function(self):
|
||||
def fn(a, b):
|
||||
return a + b
|
||||
self.assertEqual(('a', 'b'), util.fn_args(fn))
|
||||
|
||||
def test_callable(self):
|
||||
|
||||
class Foo(object):
|
||||
|
||||
def __call__(self, a, b):
|
||||
return a + b
|
||||
|
||||
self.assertEqual(('self', 'a', 'b'), util.fn_args(Foo()))
|
||||
|
||||
def test_partial_function(self):
|
||||
expected_test_arg = 123
|
||||
|
||||
def fn(a, test_arg):
|
||||
if test_arg != expected_test_arg:
|
||||
return ValueError('partial fn does not work correctly')
|
||||
return a
|
||||
|
||||
wrapped_fn = functools.partial(fn, test_arg=123)
|
||||
|
||||
self.assertEqual(('a',), util.fn_args(wrapped_fn))
|
||||
|
||||
def test_partial_function_with_positional_args(self):
|
||||
expected_test_arg = 123
|
||||
|
||||
def fn(test_arg, a):
|
||||
if test_arg != expected_test_arg:
|
||||
return ValueError('partial fn does not work correctly')
|
||||
return a
|
||||
|
||||
wrapped_fn = functools.partial(fn, 123)
|
||||
|
||||
self.assertEqual(('a',), util.fn_args(wrapped_fn))
|
||||
|
||||
self.assertEqual(3, wrapped_fn(3))
|
||||
self.assertEqual(3, wrapped_fn(a=3))
|
||||
|
||||
def test_double_partial(self):
|
||||
expected_test_arg1 = 123
|
||||
expected_test_arg2 = 456
|
||||
|
||||
def fn(a, test_arg1, test_arg2):
|
||||
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
|
||||
return ValueError('partial does not work correctly')
|
||||
return a
|
||||
|
||||
wrapped_fn = functools.partial(fn, test_arg2=456)
|
||||
double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
|
||||
|
||||
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
|
||||
|
||||
def test_double_partial_with_positional_args_in_outer_layer(self):
|
||||
expected_test_arg1 = 123
|
||||
expected_test_arg2 = 456
|
||||
|
||||
def fn(test_arg1, a, test_arg2):
|
||||
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
|
||||
return ValueError('partial fn does not work correctly')
|
||||
return a
|
||||
|
||||
wrapped_fn = functools.partial(fn, test_arg2=456)
|
||||
double_wrapped_fn = functools.partial(wrapped_fn, 123)
|
||||
|
||||
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
|
||||
|
||||
self.assertEqual(3, double_wrapped_fn(3))
|
||||
self.assertEqual(3, double_wrapped_fn(a=3))
|
||||
|
||||
def test_double_partial_with_positional_args_in_both_layers(self):
|
||||
expected_test_arg1 = 123
|
||||
expected_test_arg2 = 456
|
||||
|
||||
def fn(test_arg1, test_arg2, a):
|
||||
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
|
||||
return ValueError('partial fn does not work correctly')
|
||||
return a
|
||||
|
||||
wrapped_fn = functools.partial(fn, 123) # binds to test_arg1
|
||||
double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2
|
||||
|
||||
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
|
||||
|
||||
self.assertEqual(3, double_wrapped_fn(3))
|
||||
self.assertEqual(3, double_wrapped_fn(a=3))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user