Add trackable ModelFunction object. This will be used to export Estimators to the V2 SavedModel style.
Currently exports these arguments from the EstimatorSpec - Predictions - Train op Still need to add: - losses - export_outputs - init_op (scaffold's local_init_op) - eval_metric_ops In addition, made minor changes to training.global_step to make the functions compatible with eager mode, and fixed an issue in model_fn.train_op. PiperOrigin-RevId: 242231618
This commit is contained in:
parent
3e7afacf1d
commit
d78a3d1cbe
@ -316,12 +316,16 @@ class Function(object):
|
||||
return weak_wrapped_fn().__wrapped__(*args, **kwds)
|
||||
weak_wrapped_fn = weakref.ref(wrapped_fn)
|
||||
|
||||
return self._defun(tf_decorator.make_decorator(
|
||||
self._python_function,
|
||||
wrapped_fn,
|
||||
decorator_argspec=self._function_spec.fullargspec))
|
||||
|
||||
def _defun(self, fn):
|
||||
"""Returns a defun generated from the input function."""
|
||||
# TODO(mdan): Pipe self._experimental_autograph_options through.
|
||||
return function_lib.defun(
|
||||
tf_decorator.make_decorator(
|
||||
self._python_function,
|
||||
wrapped_fn,
|
||||
decorator_argspec=self._function_spec.fullargspec),
|
||||
fn,
|
||||
input_signature=self.input_signature,
|
||||
autograph=self._autograph,
|
||||
experimental_autograph_options=self._experimental_autograph_options)
|
||||
|
@ -230,7 +230,7 @@ def _get_or_create_global_step_read(graph=None):
|
||||
return None
|
||||
# add 'zero' so that it will create a copy of variable as Tensor.
|
||||
with graph.as_default() as g, g.name_scope(None):
|
||||
with g.name_scope(global_step_tensor.op.name + '/'):
|
||||
with g.name_scope(global_step_tensor._shared_name + '/'): # pylint: disable=protected-access
|
||||
# using initialized_value to ensure that global_step is initialized before
|
||||
# this run. This is needed for example Estimator makes all model_fn build
|
||||
# under global_step_read_tensor dependency.
|
||||
@ -242,6 +242,7 @@ def _get_or_create_global_step_read(graph=None):
|
||||
|
||||
|
||||
def _increment_global_step(increment, graph=None):
|
||||
"""Increments the global step variable."""
|
||||
graph = graph or ops.get_default_graph()
|
||||
global_step_tensor = get_global_step(graph)
|
||||
if global_step_tensor is None:
|
||||
@ -250,6 +251,6 @@ def _increment_global_step(increment, graph=None):
|
||||
'tf.train.get_or_create_global_step before calling increment.')
|
||||
global_step_read_tensor = _get_or_create_global_step_read(graph)
|
||||
with graph.as_default() as g, g.name_scope(None):
|
||||
with g.name_scope(global_step_tensor.op.name + '/'):
|
||||
with g.name_scope(global_step_tensor._shared_name + '/'): # pylint: disable=protected-access
|
||||
with ops.control_dependencies([global_step_read_tensor]):
|
||||
return state_ops.assign_add(global_step_tensor, increment)
|
||||
|
Loading…
Reference in New Issue
Block a user