Override "map_resources" in AggregatingVariable.

PiperOrigin-RevId: 317398115
Change-Id: Ic57325aacf0fb45a66a469428f544fb94f3cc031
This commit is contained in:
Chenkai Kuang 2020-06-19 16:08:23 -07:00 committed by TensorFlower Gardener
parent e264e71a44
commit fac99746cb

View File

@ -166,6 +166,14 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor):
def _gather_saveables_for_checkpoint(self):
return {trackable.VARIABLE_VALUE_KEY: self._v}
def _map_resources(self):
"""For implementing `Trackable`."""
# By delegating this method to the wrapped variable, SavedModel with
# AggregatingVariable are identical to SavedModel with normal variables.
obj_map, resource_map = self._v._map_resources() # pylint:disable=protected-access
obj_map[self] = obj_map[self._v]
return obj_map, resource_map
# pylint: disable=multiple-statements
def __add__(self, o):
return self._v + o