Deterministic sort of keras network _checkpoint_dependencies to enable MultiWorkerMirroredStrategy checkpointing and prevent different tensors being broadcast for the same CollectiveOps key.

PiperOrigin-RevId: 299153957
Change-Id: I375d4d8f9c5e256c29328f9cf887d29936224849
This commit is contained in:
Kathy Ruan 2020-03-05 11:43:48 -08:00 committed by TensorFlower Gardener
parent d0b139a856
commit 9d7566d0cb

View File

@ -411,7 +411,7 @@ class Network(base_layer.Layer):
def _checkpoint_dependencies(self):
dependencies = [
trackable.TrackableReference(name=name, ref=layer)
for name, layer in self._layer_checkpoint_dependencies.items()]
for name, layer in sorted(self._layer_checkpoint_dependencies.items())]
dependencies.extend(super(Network, self)._checkpoint_dependencies)
return dependencies