diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index 67605c8fc67..5f6cff77a97 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -45,14 +45,19 @@ OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" VARIABLE_VALUE_KEY = "VARIABLE_VALUE" OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON" -TrackableReference = collections.namedtuple( - "TrackableReference", - [ - # The local name for this dependency. - "name", - # The Trackable object being referenced. - "ref" - ]) + +@tf_export("__internal__.tracking.TrackableReference", v1=[]) +class TrackableReference( + collections.namedtuple("TrackableReference", ["name", "ref"])): + """A named reference to a trackable object for use with the `Trackable` class. + + These references mark named `Trackable` dependencies of a `Trackable` object + and should be created when overriding `Trackable._checkpoint_dependencies`. + + Attributes: + name: The local name for this dependency. + ref: The `Trackable` object being referenced. + """ # TODO(bfontain): Update once sharded initialization interface is finalized. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-trackable-reference.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-trackable-reference.pbtxt new file mode 100644 index 00000000000..447171bedac --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.-trackable-reference.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.__internal__.tracking.TrackableReference" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "name" + mtype: "" + } + member { + name: "ref" + mtype: "" + } + member_method { + name: "__init__" + } + member_method { + name: "count" + } + member_method { + name: "index" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt index 223a2472e93..97b44fd9714 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.__internal__.tracking.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "Trackable" mtype: "" } + member { + name: "TrackableReference" + mtype: "" + } member_method { name: "wrap" argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"