Expose TrackableReference as tf.__internal__

PiperOrigin-RevId: 361175756
Change-Id: Iba1d8a079afc1784b3318a8c0a57cdb7bf281294
This commit is contained in:
Matt Watson 2021-03-05 10:33:19 -08:00 committed by TensorFlower Gardener
parent 174fbcfed4
commit caaea7ca1a
3 changed files with 40 additions and 8 deletions

View File

@ -45,14 +45,19 @@ OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
OBJECT_CONFIG_JSON_KEY = "OBJECT_CONFIG_JSON"
TrackableReference = collections.namedtuple(
"TrackableReference",
[
# The local name for this dependency.
"name",
# The Trackable object being referenced.
"ref"
])
@tf_export("__internal__.tracking.TrackableReference", v1=[])
class TrackableReference(
collections.namedtuple("TrackableReference", ["name", "ref"])):
"""A named reference to a trackable object for use with the `Trackable` class.
These references mark named `Trackable` dependencies of a `Trackable` object
and should be created when overriding `Trackable._checkpoint_dependencies`.
Attributes:
name: The local name for this dependency.
ref: The `Trackable` object being referenced.
"""
# TODO(bfontain): Update once sharded initialization interface is finalized.

View File

@ -0,0 +1,23 @@
path: "tensorflow.__internal__.tracking.TrackableReference"
tf_class {
is_instance: "<class \'tensorflow.python.training.tracking.base.TrackableReference\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.TrackableReference\'>"
is_instance: "<type \'tuple\'>"
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "ref"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -8,6 +8,10 @@ tf_module {
name: "Trackable"
mtype: "<type \'type\'>"
}
member {
name: "TrackableReference"
mtype: "<type \'type\'>"
}
member_method {
name: "wrap"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"