Add trackable ModelFunction object. This will be used to export Estimators to the V2 SavedModel style.
Currently exports these arguments from the EstimatorSpec - Predictions - Train op Still need to add: - losses - export_outputs - init_op (scaffold's local_init_op) - eval_metric_ops In addition, made minor changes to training.global_step to make the functions compatible with eager mode, and fixed an issue in model_fn.train_op. PiperOrigin-RevId: 242231618
This commit is contained in:
parent
3e7afacf1d
commit
d78a3d1cbe
@ -316,12 +316,16 @@ class Function(object):
|
|||||||
return weak_wrapped_fn().__wrapped__(*args, **kwds)
|
return weak_wrapped_fn().__wrapped__(*args, **kwds)
|
||||||
weak_wrapped_fn = weakref.ref(wrapped_fn)
|
weak_wrapped_fn = weakref.ref(wrapped_fn)
|
||||||
|
|
||||||
|
return self._defun(tf_decorator.make_decorator(
|
||||||
|
self._python_function,
|
||||||
|
wrapped_fn,
|
||||||
|
decorator_argspec=self._function_spec.fullargspec))
|
||||||
|
|
||||||
|
def _defun(self, fn):
|
||||||
|
"""Returns a defun generated from the input function."""
|
||||||
# TODO(mdan): Pipe self._experimental_autograph_options through.
|
# TODO(mdan): Pipe self._experimental_autograph_options through.
|
||||||
return function_lib.defun(
|
return function_lib.defun(
|
||||||
tf_decorator.make_decorator(
|
fn,
|
||||||
self._python_function,
|
|
||||||
wrapped_fn,
|
|
||||||
decorator_argspec=self._function_spec.fullargspec),
|
|
||||||
input_signature=self.input_signature,
|
input_signature=self.input_signature,
|
||||||
autograph=self._autograph,
|
autograph=self._autograph,
|
||||||
experimental_autograph_options=self._experimental_autograph_options)
|
experimental_autograph_options=self._experimental_autograph_options)
|
||||||
|
@ -230,7 +230,7 @@ def _get_or_create_global_step_read(graph=None):
|
|||||||
return None
|
return None
|
||||||
# add 'zero' so that it will create a copy of variable as Tensor.
|
# add 'zero' so that it will create a copy of variable as Tensor.
|
||||||
with graph.as_default() as g, g.name_scope(None):
|
with graph.as_default() as g, g.name_scope(None):
|
||||||
with g.name_scope(global_step_tensor.op.name + '/'):
|
with g.name_scope(global_step_tensor._shared_name + '/'): # pylint: disable=protected-access
|
||||||
# using initialized_value to ensure that global_step is initialized before
|
# using initialized_value to ensure that global_step is initialized before
|
||||||
# this run. This is needed for example Estimator makes all model_fn build
|
# this run. This is needed for example Estimator makes all model_fn build
|
||||||
# under global_step_read_tensor dependency.
|
# under global_step_read_tensor dependency.
|
||||||
@ -242,6 +242,7 @@ def _get_or_create_global_step_read(graph=None):
|
|||||||
|
|
||||||
|
|
||||||
def _increment_global_step(increment, graph=None):
|
def _increment_global_step(increment, graph=None):
|
||||||
|
"""Increments the global step variable."""
|
||||||
graph = graph or ops.get_default_graph()
|
graph = graph or ops.get_default_graph()
|
||||||
global_step_tensor = get_global_step(graph)
|
global_step_tensor = get_global_step(graph)
|
||||||
if global_step_tensor is None:
|
if global_step_tensor is None:
|
||||||
@ -250,6 +251,6 @@ def _increment_global_step(increment, graph=None):
|
|||||||
'tf.train.get_or_create_global_step before calling increment.')
|
'tf.train.get_or_create_global_step before calling increment.')
|
||||||
global_step_read_tensor = _get_or_create_global_step_read(graph)
|
global_step_read_tensor = _get_or_create_global_step_read(graph)
|
||||||
with graph.as_default() as g, g.name_scope(None):
|
with graph.as_default() as g, g.name_scope(None):
|
||||||
with g.name_scope(global_step_tensor.op.name + '/'):
|
with g.name_scope(global_step_tensor._shared_name + '/'): # pylint: disable=protected-access
|
||||||
with ops.control_dependencies([global_step_read_tensor]):
|
with ops.control_dependencies([global_step_read_tensor]):
|
||||||
return state_ops.assign_add(global_step_tensor, increment)
|
return state_ops.assign_add(global_step_tensor, increment)
|
||||||
|
Loading…
Reference in New Issue
Block a user