AutoGraph now supports new creating new symbols in a loop, so this instance can be simplified. The code will raise a runtime error if steps_per_execution is zero. Although this should not reduce the tracing time significantly (the loop body is still traced twice internally), it should reduce the graph size as the initial iteration is discarded.
PiperOrigin-RevId: 326542189 Change-Id: Ia5db6003d07ea33d4d6f44c2cdb2e6133586f72d
This commit is contained in:
parent
cc64f9e424
commit
2a1992b8db
@ -780,8 +780,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
|
||||
def train_function(iterator):
|
||||
"""Runs a training execution with multiple steps."""
|
||||
outputs = step_function(self, iterator)
|
||||
for _ in math_ops.range(self._steps_per_execution - 1):
|
||||
for _ in math_ops.range(self._steps_per_execution):
|
||||
outputs = step_function(self, iterator)
|
||||
return outputs
|
||||
|
||||
@ -1201,8 +1200,7 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
|
||||
def test_function(iterator):
|
||||
"""Runs an evaluation execution with multiple steps."""
|
||||
outputs = step_function(self, iterator)
|
||||
for _ in math_ops.range(self._steps_per_execution - 1):
|
||||
for _ in math_ops.range(self._steps_per_execution):
|
||||
outputs = step_function(self, iterator)
|
||||
return outputs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user