Override "map_resources" in AggregatingVariable.
PiperOrigin-RevId: 317398115 Change-Id: Ic57325aacf0fb45a66a469428f544fb94f3cc031
This commit is contained in:
parent
e264e71a44
commit
fac99746cb
@ -166,6 +166,14 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor):
|
|||||||
def _gather_saveables_for_checkpoint(self):
|
def _gather_saveables_for_checkpoint(self):
|
||||||
return {trackable.VARIABLE_VALUE_KEY: self._v}
|
return {trackable.VARIABLE_VALUE_KEY: self._v}
|
||||||
|
|
||||||
|
def _map_resources(self):
|
||||||
|
"""For implementing `Trackable`."""
|
||||||
|
# By delegating this method to the wrapped variable, SavedModel with
|
||||||
|
# AggregatingVariable are identical to SavedModel with normal variables.
|
||||||
|
obj_map, resource_map = self._v._map_resources() # pylint:disable=protected-access
|
||||||
|
obj_map[self] = obj_map[self._v]
|
||||||
|
return obj_map, resource_map
|
||||||
|
|
||||||
# pylint: disable=multiple-statements
|
# pylint: disable=multiple-statements
|
||||||
def __add__(self, o):
|
def __add__(self, o):
|
||||||
return self._v + o
|
return self._v + o
|
||||||
|
Loading…
x
Reference in New Issue
Block a user