Consolidate/Merge the fn_args in multiple places (with double partial) and add full tests.
RELNOTES: n/a PiperOrigin-RevId: 161113304
This commit is contained in:
parent
70aa8daacf
commit
e7fc162658
@ -30,6 +30,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop
|
|||||||
|
|
||||||
from tensorflow.python.estimator import estimator as estimator_lib
|
from tensorflow.python.estimator import estimator as estimator_lib
|
||||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||||
|
from tensorflow.python.estimator import util
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -303,7 +304,7 @@ class TpuEstimator(estimator_lib.Estimator):
|
|||||||
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN:
|
||||||
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
|
return super(TpuEstimator, self)._call_input_fn(input_fn, mode)
|
||||||
|
|
||||||
input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access
|
input_fn_args = util.fn_args(input_fn)
|
||||||
config = self.config # a deep copy.
|
config = self.config # a deep copy.
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if 'params' in input_fn_args:
|
if 'params' in input_fn_args:
|
||||||
@ -357,7 +358,7 @@ def _verify_estimator_spec(estimator_spec):
|
|||||||
def _call_model_fn(model_fn, features, labels, mode, config, params,
|
def _call_model_fn(model_fn, features, labels, mode, config, params,
|
||||||
require_params=False):
|
require_params=False):
|
||||||
"""Calls the model_fn with required parameters."""
|
"""Calls the model_fn with required parameters."""
|
||||||
model_fn_args = estimator_lib._fn_args(model_fn) # pylint: disable=protected-access
|
model_fn_args = util.fn_args(model_fn)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if 'mode' in model_fn_args:
|
if 'mode' in model_fn_args:
|
||||||
kwargs['mode'] = mode
|
kwargs['mode'] = mode
|
||||||
|
@ -26,6 +26,7 @@ py_library(
|
|||||||
":model_fn",
|
":model_fn",
|
||||||
":parsing_utils",
|
":parsing_utils",
|
||||||
":run_config",
|
":run_config",
|
||||||
|
":util",
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -212,6 +213,27 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "util",
|
||||||
|
srcs = [
|
||||||
|
"util.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "util_test",
|
||||||
|
srcs = ["util_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":util",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "estimator",
|
name = "estimator",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -222,6 +244,7 @@ py_library(
|
|||||||
":export",
|
":export",
|
||||||
":model_fn",
|
":model_fn",
|
||||||
":run_config",
|
":run_config",
|
||||||
|
":util",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
|
|||||||
from tensorflow.python.client import session as tf_session
|
from tensorflow.python.client import session as tf_session
|
||||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||||
from tensorflow.python.estimator import run_config
|
from tensorflow.python.estimator import run_config
|
||||||
|
from tensorflow.python.estimator import util
|
||||||
from tensorflow.python.estimator.export.export import build_all_signature_defs
|
from tensorflow.python.estimator.export.export import build_all_signature_defs
|
||||||
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
|
from tensorflow.python.estimator.export.export import get_timestamped_export_dir
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -47,7 +48,6 @@ from tensorflow.python.training import monitored_session
|
|||||||
from tensorflow.python.training import saver
|
from tensorflow.python.training import saver
|
||||||
from tensorflow.python.training import training
|
from tensorflow.python.training import training
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import tf_decorator
|
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
|
||||||
@ -575,7 +575,7 @@ class Estimator(object):
|
|||||||
ValueError: if input_fn takes invalid arguments.
|
ValueError: if input_fn takes invalid arguments.
|
||||||
"""
|
"""
|
||||||
del mode # unused
|
del mode # unused
|
||||||
input_fn_args = _fn_args(input_fn)
|
input_fn_args = util.fn_args(input_fn)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if 'params' in input_fn_args:
|
if 'params' in input_fn_args:
|
||||||
kwargs['params'] = self.params
|
kwargs['params'] = self.params
|
||||||
@ -598,7 +598,7 @@ class Estimator(object):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: if model_fn returns invalid objects.
|
ValueError: if model_fn returns invalid objects.
|
||||||
"""
|
"""
|
||||||
model_fn_args = _fn_args(self._model_fn)
|
model_fn_args = util.fn_args(self._model_fn)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if 'labels' in model_fn_args:
|
if 'labels' in model_fn_args:
|
||||||
kwargs['labels'] = labels
|
kwargs['labels'] = labels
|
||||||
@ -791,35 +791,9 @@ def _get_replica_device_setter(config):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _fn_args(fn):
|
|
||||||
"""Get argument names for function-like object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fn: Function, or function-like object (e.g., result of `functools.partial`).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
`tuple` of string argument names.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if partial function has positionally bound arguments
|
|
||||||
"""
|
|
||||||
_, fn = tf_decorator.unwrap(fn)
|
|
||||||
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
|
|
||||||
# Handle callables.
|
|
||||||
return tuple(tf_inspect.getargspec(fn.__call__).args)
|
|
||||||
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
|
|
||||||
# Handle functools.partial and similar objects.
|
|
||||||
return tuple([
|
|
||||||
arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
|
|
||||||
if arg not in set(fn.keywords.keys())
|
|
||||||
])
|
|
||||||
# Handle function.
|
|
||||||
return tuple(tf_inspect.getargspec(fn).args)
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_model_fn_args(model_fn, params):
|
def _verify_model_fn_args(model_fn, params):
|
||||||
"""Verifies model fn arguments."""
|
"""Verifies model fn arguments."""
|
||||||
args = set(_fn_args(model_fn))
|
args = set(util.fn_args(model_fn))
|
||||||
if 'features' not in args:
|
if 'features' not in args:
|
||||||
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
|
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
|
||||||
if params is not None and 'params' not in args:
|
if params is not None and 'params' not in args:
|
||||||
|
57
tensorflow/python/estimator/util.py
Normal file
57
tensorflow/python/estimator/util.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Utility to retrieve function args.."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.util import tf_decorator
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
|
||||||
|
def fn_args(fn):
|
||||||
|
"""Get argument names for function-like object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: Function, or function-like object (e.g., result of `functools.partial`).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tuple` of string argument names.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if partial function has positionally bound arguments
|
||||||
|
"""
|
||||||
|
_, fn = tf_decorator.unwrap(fn)
|
||||||
|
|
||||||
|
# Handle callables.
|
||||||
|
if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
|
||||||
|
return tuple(tf_inspect.getargspec(fn.__call__).args)
|
||||||
|
|
||||||
|
# Handle functools.partial and similar objects.
|
||||||
|
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
|
||||||
|
# Handle nested partial.
|
||||||
|
original_args = fn_args(fn.func)
|
||||||
|
if not original_args:
|
||||||
|
return tuple()
|
||||||
|
|
||||||
|
return tuple([
|
||||||
|
arg for arg in original_args[len(fn.args):]
|
||||||
|
if arg not in set((fn.keywords or {}).keys())
|
||||||
|
])
|
||||||
|
|
||||||
|
# Handle function.
|
||||||
|
return tuple(tf_inspect.getargspec(fn).args)
|
119
tensorflow/python/estimator/util_test.py
Normal file
119
tensorflow/python/estimator/util_test.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for Estimator related util."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from tensorflow.python.estimator import util
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class FnArgsTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_simple_function(self):
|
||||||
|
def fn(a, b):
|
||||||
|
return a + b
|
||||||
|
self.assertEqual(('a', 'b'), util.fn_args(fn))
|
||||||
|
|
||||||
|
def test_callable(self):
|
||||||
|
|
||||||
|
class Foo(object):
|
||||||
|
|
||||||
|
def __call__(self, a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
self.assertEqual(('self', 'a', 'b'), util.fn_args(Foo()))
|
||||||
|
|
||||||
|
def test_partial_function(self):
|
||||||
|
expected_test_arg = 123
|
||||||
|
|
||||||
|
def fn(a, test_arg):
|
||||||
|
if test_arg != expected_test_arg:
|
||||||
|
return ValueError('partial fn does not work correctly')
|
||||||
|
return a
|
||||||
|
|
||||||
|
wrapped_fn = functools.partial(fn, test_arg=123)
|
||||||
|
|
||||||
|
self.assertEqual(('a',), util.fn_args(wrapped_fn))
|
||||||
|
|
||||||
|
def test_partial_function_with_positional_args(self):
|
||||||
|
expected_test_arg = 123
|
||||||
|
|
||||||
|
def fn(test_arg, a):
|
||||||
|
if test_arg != expected_test_arg:
|
||||||
|
return ValueError('partial fn does not work correctly')
|
||||||
|
return a
|
||||||
|
|
||||||
|
wrapped_fn = functools.partial(fn, 123)
|
||||||
|
|
||||||
|
self.assertEqual(('a',), util.fn_args(wrapped_fn))
|
||||||
|
|
||||||
|
self.assertEqual(3, wrapped_fn(3))
|
||||||
|
self.assertEqual(3, wrapped_fn(a=3))
|
||||||
|
|
||||||
|
def test_double_partial(self):
|
||||||
|
expected_test_arg1 = 123
|
||||||
|
expected_test_arg2 = 456
|
||||||
|
|
||||||
|
def fn(a, test_arg1, test_arg2):
|
||||||
|
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
|
||||||
|
return ValueError('partial does not work correctly')
|
||||||
|
return a
|
||||||
|
|
||||||
|
wrapped_fn = functools.partial(fn, test_arg2=456)
|
||||||
|
double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
|
||||||
|
|
||||||
|
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
|
||||||
|
|
||||||
|
def test_double_partial_with_positional_args_in_outer_layer(self):
|
||||||
|
expected_test_arg1 = 123
|
||||||
|
expected_test_arg2 = 456
|
||||||
|
|
||||||
|
def fn(test_arg1, a, test_arg2):
|
||||||
|
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
|
||||||
|
return ValueError('partial fn does not work correctly')
|
||||||
|
return a
|
||||||
|
|
||||||
|
wrapped_fn = functools.partial(fn, test_arg2=456)
|
||||||
|
double_wrapped_fn = functools.partial(wrapped_fn, 123)
|
||||||
|
|
||||||
|
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
|
||||||
|
|
||||||
|
self.assertEqual(3, double_wrapped_fn(3))
|
||||||
|
self.assertEqual(3, double_wrapped_fn(a=3))
|
||||||
|
|
||||||
|
def test_double_partial_with_positional_args_in_both_layers(self):
|
||||||
|
expected_test_arg1 = 123
|
||||||
|
expected_test_arg2 = 456
|
||||||
|
|
||||||
|
def fn(test_arg1, test_arg2, a):
|
||||||
|
if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
|
||||||
|
return ValueError('partial fn does not work correctly')
|
||||||
|
return a
|
||||||
|
|
||||||
|
wrapped_fn = functools.partial(fn, 123) # binds to test_arg1
|
||||||
|
double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2
|
||||||
|
|
||||||
|
self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
|
||||||
|
|
||||||
|
self.assertEqual(3, double_wrapped_fn(3))
|
||||||
|
self.assertEqual(3, double_wrapped_fn(a=3))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
Reference in New Issue
Block a user