AutoGraph now supports new creating new symbols in a loop, so this instance can be simplified. The code will raise a runtime error if steps_per_execution is zero. Although this should not reduce the tracing time significantly (the loop body is still traced twice internally), it should reduce the graph size as the initial iteration is discarded.

PiperOrigin-RevId: 326542189
Change-Id: Ia5db6003d07ea33d4d6f44c2cdb2e6133586f72d
This commit is contained in:
Dan Moldovan 2020-08-13 15:57:02 -07:00 committed by TensorFlower Gardener
parent cc64f9e424
commit 2a1992b8db

View File

@ -780,8 +780,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
def train_function(iterator):
"""Runs a training execution with multiple steps."""
outputs = step_function(self, iterator)
for _ in math_ops.range(self._steps_per_execution - 1):
for _ in math_ops.range(self._steps_per_execution):
outputs = step_function(self, iterator)
return outputs
@ -1201,8 +1200,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
def test_function(iterator):
"""Runs an evaluation execution with multiple steps."""
outputs = step_function(self, iterator)
for _ in math_ops.range(self._steps_per_execution - 1):
for _ in math_ops.range(self._steps_per_execution):
outputs = step_function(self, iterator)
return outputs