From d78a3d1cbe3720a9b471cfba57696014dc8eaab0 Mon Sep 17 00:00:00 2001
From: Katherine Wu
Date: Fri, 5 Apr 2019 18:51:58 -0700
Subject: [PATCH] Add trackable ModelFunction object.

This will be used to export Estimators to the V2 SavedModel style.

Currently exports these arguments from the EstimatorSpec:
- Predictions
- Train op

Still need to add:
- losses
- export_outputs
- init_op (scaffold's local_init_op)
- eval_metric_ops

In addition, made minor changes to training.global_step to make the
functions compatible with eager mode, and fixed an issue in
model_fn.train_op.

PiperOrigin-RevId: 242231618
---
 tensorflow/python/eager/def_function.py     | 12 ++++++++----
 tensorflow/python/training/training_util.py |  5 +++--
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index a150baafd60..9b46001506b 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -316,12 +316,16 @@ class Function(object):
       return weak_wrapped_fn().__wrapped__(*args, **kwds)
     weak_wrapped_fn = weakref.ref(wrapped_fn)
 
+    return self._defun(tf_decorator.make_decorator(
+        self._python_function,
+        wrapped_fn,
+        decorator_argspec=self._function_spec.fullargspec))
+
+  def _defun(self, fn):
+    """Returns a defun generated from the input function."""
     # TODO(mdan): Pipe self._experimental_autograph_options through.
     return function_lib.defun(
-        tf_decorator.make_decorator(
-            self._python_function,
-            wrapped_fn,
-            decorator_argspec=self._function_spec.fullargspec),
+        fn,
         input_signature=self.input_signature,
         autograph=self._autograph,
         experimental_autograph_options=self._experimental_autograph_options)
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 86f1b4d5aae..d526a17297f 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -230,7 +230,7 @@ def _get_or_create_global_step_read(graph=None):
     return None
   # add 'zero' so that it will create a copy of variable as Tensor.
   with graph.as_default() as g, g.name_scope(None):
-    with g.name_scope(global_step_tensor.op.name + '/'):
+    with g.name_scope(global_step_tensor._shared_name + '/'):  # pylint: disable=protected-access
       # using initialized_value to ensure that global_step is initialized before
       # this run. This is needed for example Estimator makes all model_fn build
       # under global_step_read_tensor dependency.
@@ -242,6 +242,7 @@
 
 
 def _increment_global_step(increment, graph=None):
+  """Increments the global step variable."""
   graph = graph or ops.get_default_graph()
   global_step_tensor = get_global_step(graph)
   if global_step_tensor is None:
@@ -250,6 +251,6 @@
         'tf.train.get_or_create_global_step before calling increment.')
   global_step_read_tensor = _get_or_create_global_step_read(graph)
   with graph.as_default() as g, g.name_scope(None):
-    with g.name_scope(global_step_tensor.op.name + '/'):
+    with g.name_scope(global_step_tensor._shared_name + '/'):  # pylint: disable=protected-access
       with ops.control_dependencies([global_step_read_tensor]):
         return state_ops.assign_add(global_step_tensor, increment)
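
Note: the def_function.py change factors defun creation out into a separate
_defun() hook, presumably so the ModelFunction wrapper described above can
customize how the underlying defun is built without duplicating the
decorator/weakref plumbing around it. An illustrative sketch of the override
pattern this enables (the subclass name and the print statement are
hypothetical, not part of this patch):

    from tensorflow.python.eager import def_function


    class TracingFunction(def_function.Function):
      """Illustrative Function subclass that customizes defun generation."""

      def _defun(self, fn):
        # Hypothetical customization point: log each defun generation, then
        # defer to the base implementation, which applies input_signature
        # and the autograph settings.
        print('Generating defun for %s' % self._python_function.__name__)
        return super(TracingFunction, self)._defun(fn)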
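
Note: the training_util.py changes swap op.name for _shared_name because an
eagerly created resource variable has no graph op, so
global_step_tensor.op.name raises under eager execution while the shared
name is still a plain string. A minimal sketch of the failure mode, assuming
TF 1.13-era APIs (the variable below is illustrative):

    import tensorflow as tf

    tf.enable_eager_execution()

    step = tf.Variable(0, name='global_step', trainable=False)

    try:
      # Eager tensors have no graph op, so deriving a name scope from
      # `op.name` breaks once eager execution is enabled.
      print(step.op.name)
    except AttributeError as e:
      print('op.name unavailable in eager mode: %s' % e)

    # The shared name is still available as a string, so it can be used to
    # build the name scope instead.
    print(step._shared_name)  # pylint: disable=protected-access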