Override "map_resources" in AggregatingVariable.
PiperOrigin-RevId: 317398115 Change-Id: Ic57325aacf0fb45a66a469428f544fb94f3cc031
This commit is contained in:
parent
e264e71a44
commit
fac99746cb
@ -166,6 +166,14 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor):
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
return {trackable.VARIABLE_VALUE_KEY: self._v}
|
||||
|
||||
def _map_resources(self):
|
||||
"""For implementing `Trackable`."""
|
||||
# By delegating this method to the wrapped variable, SavedModel with
|
||||
# AggregatingVariable are identical to SavedModel with normal variables.
|
||||
obj_map, resource_map = self._v._map_resources() # pylint:disable=protected-access
|
||||
obj_map[self] = obj_map[self._v]
|
||||
return obj_map, resource_map
|
||||
|
||||
# pylint: disable=multiple-statements
|
||||
def __add__(self, o):
|
||||
return self._v + o
|
||||
|
Loading…
Reference in New Issue
Block a user