Add expand_composite
option to tf.Module._flatten
to support expanding composite tensors.
Use the said `expand_composite` option in `tf.Module.variables` like APIs to support expanding composite tensors that are a collection of variables. PiperOrigin-RevId: 338300859 Change-Id: Iddcb1f34a87557e9de15d1887fed6d3b61319301
This commit is contained in:
parent
1c785764d4
commit
bb41d04dd0
tensorflow/python/module
@ -27,9 +27,18 @@ tf_py_test(
|
||||
deps = [
|
||||
":module",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/distribute:ps_values",
|
||||
"//tensorflow/python/distribute:tpu_values",
|
||||
"//tensorflow/python/distribute:values",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -157,7 +157,7 @@ class Module(tracking.AutoTrackable):
|
||||
name) followed by variables from all submodules recursively (breadth
|
||||
first).
|
||||
"""
|
||||
return tuple(self._flatten(predicate=_is_variable))
|
||||
return tuple(self._flatten(predicate=_is_variable, expand_composites=True))
|
||||
|
||||
@property
|
||||
def trainable_variables(self):
|
||||
@ -172,7 +172,8 @@ class Module(tracking.AutoTrackable):
|
||||
name) followed by variables from all submodules recursively (breadth
|
||||
first).
|
||||
"""
|
||||
return tuple(self._flatten(predicate=_is_trainable_variable))
|
||||
return tuple(
|
||||
self._flatten(predicate=_is_trainable_variable, expand_composites=True))
|
||||
|
||||
@property
|
||||
def submodules(self):
|
||||
@ -202,7 +203,8 @@ class Module(tracking.AutoTrackable):
|
||||
recursive=True,
|
||||
predicate=None,
|
||||
attribute_traversal_key=None,
|
||||
with_path=False):
|
||||
with_path=False,
|
||||
expand_composites=False):
|
||||
"""Flattened attribute values in sorted order by attribute name.
|
||||
|
||||
Modules are flattened by first walking their attributes in name order.
|
||||
@ -247,6 +249,8 @@ class Module(tracking.AutoTrackable):
|
||||
as the object itself. If `with_path` is `True` then leaves will not be
|
||||
de-duplicated (e.g. if the same leaf instance is reachable via multiple
|
||||
modules then it will be yielded multiple times with different paths).
|
||||
expand_composites: If true, then composite tensors are expanded into their
|
||||
component tensors.
|
||||
|
||||
Returns:
|
||||
Flat generator for leaves of the current module and optionally all
|
||||
@ -261,7 +265,8 @@ class Module(tracking.AutoTrackable):
|
||||
predicate=predicate,
|
||||
attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
|
||||
attribute_traversal_key=attribute_traversal_key,
|
||||
with_path=with_path)
|
||||
with_path=with_path,
|
||||
expand_composites=expand_composites)
|
||||
|
||||
@classmethod
|
||||
def with_name_scope(cls, method):
|
||||
@ -326,6 +331,7 @@ def _flatten_module(module,
|
||||
attribute_traversal_key,
|
||||
attributes_to_ignore,
|
||||
with_path,
|
||||
expand_composites,
|
||||
module_path=(),
|
||||
seen=None):
|
||||
"""Implementation of `flatten`."""
|
||||
@ -341,7 +347,8 @@ def _flatten_module(module,
|
||||
|
||||
prop = module_dict[key]
|
||||
try:
|
||||
leaves = nest.flatten_with_tuple_paths(prop)
|
||||
leaves = nest.flatten_with_tuple_paths(
|
||||
prop, expand_composites=expand_composites)
|
||||
except Exception as cause: # pylint: disable=broad-except
|
||||
six.raise_from(
|
||||
ValueError(
|
||||
@ -376,6 +383,7 @@ def _flatten_module(module,
|
||||
attribute_traversal_key=attribute_traversal_key,
|
||||
attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES, # pylint: disable=protected-access
|
||||
with_path=with_path,
|
||||
expand_composites=expand_composites,
|
||||
module_path=submodule_path,
|
||||
seen=seen)
|
||||
|
||||
|
@ -31,8 +31,10 @@ from tensorflow.python.distribute import tpu_values
|
||||
from tensorflow.python.distribute import values as distributed_values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.module import module
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -260,6 +262,37 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
|
||||
m.c = aggregating
|
||||
self.assertEqual(m.variables, (mirrored, tpu, aggregating))
|
||||
|
||||
def test_composite_variable(self):
|
||||
|
||||
class Spec(type_spec.TypeSpec):
|
||||
|
||||
value_type = property(lambda self: CompositeVariable)
|
||||
|
||||
def _component_specs(self):
|
||||
pass
|
||||
|
||||
def _serialize(self):
|
||||
pass
|
||||
|
||||
def _to_components(self, value):
|
||||
return value._variables
|
||||
|
||||
def _from_components(self, variable_list):
|
||||
return CompositeVariable(variable_list)
|
||||
|
||||
class CompositeVariable(composite_tensor.CompositeTensor):
|
||||
|
||||
def __init__(self, variable_list):
|
||||
self._variables = variable_list
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return Spec()
|
||||
|
||||
m = module.Module()
|
||||
m.a = CompositeVariable([variables.Variable(1.), variables.Variable(2.)])
|
||||
self.assertAllEqual(m.variables, m.a._variables)
|
||||
|
||||
|
||||
class ModuleTrackingTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user