Add trackable ModelFunction object. This will be used to export Estimators to the V2 SavedModel style.

Currently exports these arguments from the EstimatorSpec
- Predictions
- Train op

Still need to add:
- losses
- export_outputs
- init_op (scaffold's local_init_op)
- eval_metric_ops

In addition, made minor changes to training.global_step to make the functions compatible with eager mode, and fixed an issue in model_fn.train_op.

PiperOrigin-RevId: 242231618
This commit is contained in:
Katherine Wu 2019-04-05 18:51:58 -07:00 committed by TensorFlower Gardener
parent 3e7afacf1d
commit d78a3d1cbe
2 changed files with 11 additions and 6 deletions

View File

@@ -316,12 +316,16 @@ class Function(object):
       return weak_wrapped_fn().__wrapped__(*args, **kwds)
     weak_wrapped_fn = weakref.ref(wrapped_fn)

-    # TODO(mdan): Pipe self._experimental_autograph_options through.
-    return function_lib.defun(
-        tf_decorator.make_decorator(
-            self._python_function,
-            wrapped_fn,
-            decorator_argspec=self._function_spec.fullargspec),
+    return self._defun(tf_decorator.make_decorator(
+        self._python_function,
+        wrapped_fn,
+        decorator_argspec=self._function_spec.fullargspec))
+
+  def _defun(self, fn):
+    """Returns a defun generated from the input function."""
+    # TODO(mdan): Pipe self._experimental_autograph_options through.
+    return function_lib.defun(
+        fn,
         input_signature=self.input_signature,
         autograph=self._autograph,
         experimental_autograph_options=self._experimental_autograph_options)

View File

@@ -230,7 +230,7 @@ def _get_or_create_global_step_read(graph=None):
     return None
   # add 'zero' so that it will create a copy of variable as Tensor.
   with graph.as_default() as g, g.name_scope(None):
-    with g.name_scope(global_step_tensor.op.name + '/'):
+    with g.name_scope(global_step_tensor._shared_name + '/'):  # pylint: disable=protected-access
       # using initialized_value to ensure that global_step is initialized before
       # this run. This is needed for example Estimator makes all model_fn build
       # under global_step_read_tensor dependency.
@@ -242,6 +242,7 @@ def _get_or_create_global_step_read(graph=None):
 def _increment_global_step(increment, graph=None):
+  """Increments the global step variable."""
   graph = graph or ops.get_default_graph()
   global_step_tensor = get_global_step(graph)
   if global_step_tensor is None:
@@ -250,6 +251,6 @@ def _increment_global_step(increment, graph=None):
         'tf.train.get_or_create_global_step before calling increment.')
   global_step_read_tensor = _get_or_create_global_step_read(graph)
   with graph.as_default() as g, g.name_scope(None):
-    with g.name_scope(global_step_tensor.op.name + '/'):
+    with g.name_scope(global_step_tensor._shared_name + '/'):  # pylint: disable=protected-access
       with ops.control_dependencies([global_step_read_tensor]):
         return state_ops.assign_add(global_step_tensor, increment)