From e7fc162658db3c5435fbca17e98fcce553e1648d Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Thu, 6 Jul 2017 12:18:21 -0700 Subject: [PATCH] Consolidate/Merge the fn_args in multiple places (with double partial) and add full tests. RELNOTES: n/a PiperOrigin-RevId: 161113304 --- .../contrib/tpu/python/tpu/tpu_estimator.py | 5 +- tensorflow/python/estimator/BUILD | 23 ++++ tensorflow/python/estimator/estimator.py | 34 +---- tensorflow/python/estimator/util.py | 57 +++++++++ tensorflow/python/estimator/util_test.py | 119 ++++++++++++++++++ 5 files changed, 206 insertions(+), 32 deletions(-) create mode 100644 tensorflow/python/estimator/util.py create mode 100644 tensorflow/python/estimator/util_test.py diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index b9da8dc35ab..e001d866c35 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -30,6 +30,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -303,7 +304,7 @@ class TpuEstimator(estimator_lib.Estimator): if not self._use_tpu or mode != model_fn_lib.ModeKeys.TRAIN: return super(TpuEstimator, self)._call_input_fn(input_fn, mode) - input_fn_args = estimator_lib._fn_args(input_fn) # pylint: disable=protected-access + input_fn_args = util.fn_args(input_fn) config = self.config # a deep copy. kwargs = {} if 'params' in input_fn_args: @@ -357,7 +358,7 @@ def _verify_estimator_spec(estimator_spec): def _call_model_fn(model_fn, features, labels, mode, config, params, require_params=False): """Calls the model_fn with required parameters.""" - model_fn_args = estimator_lib._fn_args(model_fn) # pylint: disable=protected-access + model_fn_args = util.fn_args(model_fn) kwargs = {} if 'mode' in model_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index ca46daf3065..627fa8b25ec 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -26,6 +26,7 @@ py_library( ":model_fn", ":parsing_utils", ":run_config", + ":util", "//tensorflow/python:util", ], ) @@ -212,6 +213,27 @@ py_test( ], ) +py_library( + name = "util", + srcs = [ + "util.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + ], +) + +py_test( + name = "util_test", + srcs = ["util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":util", + "//tensorflow/python:client_testlib", + ], +) + py_library( name = "estimator", srcs = [ @@ -222,6 +244,7 @@ py_library( ":export", ":model_fn", ":run_config", + ":util", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index bc11a6c84f2..0c1bf7ccf53 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -31,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import util from tensorflow.python.estimator.export.export import build_all_signature_defs from tensorflow.python.estimator.export.export import get_timestamped_export_dir from tensorflow.python.framework import ops @@ -47,7 +48,6 @@ from tensorflow.python.training import monitored_session from tensorflow.python.training import saver from tensorflow.python.training import training from tensorflow.python.util import compat -from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -575,7 +575,7 @@ class Estimator(object): ValueError: if input_fn takes invalid arguments. """ del mode # unused - input_fn_args = _fn_args(input_fn) + input_fn_args = util.fn_args(input_fn) kwargs = {} if 'params' in input_fn_args: kwargs['params'] = self.params @@ -598,7 +598,7 @@ class Estimator(object): Raises: ValueError: if model_fn returns invalid objects. """ - model_fn_args = _fn_args(self._model_fn) + model_fn_args = util.fn_args(self._model_fn) kwargs = {} if 'labels' in model_fn_args: kwargs['labels'] = labels @@ -791,35 +791,9 @@ def _get_replica_device_setter(config): return None -def _fn_args(fn): - """Get argument names for function-like object. - - Args: - fn: Function, or function-like object (e.g., result of `functools.partial`). - - Returns: - `tuple` of string argument names. - - Raises: - ValueError: if partial function has positionally bound arguments - """ - _, fn = tf_decorator.unwrap(fn) - if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__): - # Handle callables. - return tuple(tf_inspect.getargspec(fn.__call__).args) - if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'): - # Handle functools.partial and similar objects. - return tuple([ - arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):] - if arg not in set(fn.keywords.keys()) - ]) - # Handle function. - return tuple(tf_inspect.getargspec(fn).args) - - def _verify_model_fn_args(model_fn, params): """Verifies model fn arguments.""" - args = set(_fn_args(model_fn)) + args = set(util.fn_args(model_fn)) if 'features' not in args: raise ValueError('model_fn (%s) must include features argument.' % model_fn) if params is not None and 'params' not in args: diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py new file mode 100644 index 00000000000..de35e66bdfb --- /dev/null +++ b/tensorflow/python/estimator/util.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utility to retrieve function args..""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect + + +def fn_args(fn): + """Get argument names for function-like object. + + Args: + fn: Function, or function-like object (e.g., result of `functools.partial`). + + Returns: + `tuple` of string argument names. + + Raises: + ValueError: if partial function has positionally bound arguments + """ + _, fn = tf_decorator.unwrap(fn) + + # Handle callables. + if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__): + return tuple(tf_inspect.getargspec(fn.__call__).args) + + # Handle functools.partial and similar objects. + if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'): + # Handle nested partial. + original_args = fn_args(fn.func) + if not original_args: + return tuple() + + return tuple([ + arg for arg in original_args[len(fn.args):] + if arg not in set((fn.keywords or {}).keys()) + ]) + + # Handle function. + return tuple(tf_inspect.getargspec(fn).args) diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py new file mode 100644 index 00000000000..3f8122c407b --- /dev/null +++ b/tensorflow/python/estimator/util_test.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Estimator related util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.estimator import util +from tensorflow.python.platform import test + + +class FnArgsTest(test.TestCase): + + def test_simple_function(self): + def fn(a, b): + return a + b + self.assertEqual(('a', 'b'), util.fn_args(fn)) + + def test_callable(self): + + class Foo(object): + + def __call__(self, a, b): + return a + b + + self.assertEqual(('self', 'a', 'b'), util.fn_args(Foo())) + + def test_partial_function(self): + expected_test_arg = 123 + + def fn(a, test_arg): + if test_arg != expected_test_arg: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, test_arg=123) + + self.assertEqual(('a',), util.fn_args(wrapped_fn)) + + def test_partial_function_with_positional_args(self): + expected_test_arg = 123 + + def fn(test_arg, a): + if test_arg != expected_test_arg: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, 123) + + self.assertEqual(('a',), util.fn_args(wrapped_fn)) + + self.assertEqual(3, wrapped_fn(3)) + self.assertEqual(3, wrapped_fn(a=3)) + + def test_double_partial(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn(a, test_arg1, test_arg2): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial does not work correctly') + return a + + wrapped_fn = functools.partial(fn, test_arg2=456) + double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) + + self.assertEqual(('a',), util.fn_args(double_wrapped_fn)) + + def test_double_partial_with_positional_args_in_outer_layer(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn(test_arg1, a, test_arg2): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, test_arg2=456) + double_wrapped_fn = functools.partial(wrapped_fn, 123) + + self.assertEqual(('a',), util.fn_args(double_wrapped_fn)) + + self.assertEqual(3, double_wrapped_fn(3)) + self.assertEqual(3, double_wrapped_fn(a=3)) + + def test_double_partial_with_positional_args_in_both_layers(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn(test_arg1, test_arg2, a): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, 123) # binds to test_arg1 + double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2 + + self.assertEqual(('a',), util.fn_args(double_wrapped_fn)) + + self.assertEqual(3, double_wrapped_fn(3)) + self.assertEqual(3, double_wrapped_fn(a=3)) + +if __name__ == '__main__': + test.main()