Add expand_composite option to tf.Module._flatten to support expanding composite tensors.

Use the said `expand_composite` option in `tf.Module.variables` like APIs to support expanding composite tensors that are a collection of variables.

PiperOrigin-RevId: 338300859
Change-Id: Iddcb1f34a87557e9de15d1887fed6d3b61319301
This commit is contained in:
Chenkai Kuang 2020-10-21 11:12:35 -07:00 committed by TensorFlower Gardener
parent 1c785764d4
commit bb41d04dd0
3 changed files with 56 additions and 6 deletions
tensorflow/python/module

View File

@ -27,9 +27,18 @@ tf_py_test(
deps = [
":module",
"//tensorflow/python:client_testlib",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:tf2",
"//tensorflow/python:type_spec",
"//tensorflow/python:variables",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute:ps_values",
"//tensorflow/python/distribute:tpu_values",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -157,7 +157,7 @@ class Module(tracking.AutoTrackable):
name) followed by variables from all submodules recursively (breadth
first).
"""
return tuple(self._flatten(predicate=_is_variable))
return tuple(self._flatten(predicate=_is_variable, expand_composites=True))
@property
def trainable_variables(self):
@ -172,7 +172,8 @@ class Module(tracking.AutoTrackable):
name) followed by variables from all submodules recursively (breadth
first).
"""
return tuple(self._flatten(predicate=_is_trainable_variable))
return tuple(
self._flatten(predicate=_is_trainable_variable, expand_composites=True))
@property
def submodules(self):
@ -202,7 +203,8 @@ class Module(tracking.AutoTrackable):
recursive=True,
predicate=None,
attribute_traversal_key=None,
with_path=False):
with_path=False,
expand_composites=False):
"""Flattened attribute values in sorted order by attribute name.
Modules are flattened by first walking their attributes in name order.
@ -247,6 +249,8 @@ class Module(tracking.AutoTrackable):
as the object itself. If `with_path` is `True` then leaves will not be
de-duplicated (e.g. if the same leaf instance is reachable via multiple
modules then it will be yielded multiple times with different paths).
expand_composites: If true, then composite tensors are expanded into their
component tensors.
Returns:
Flat generator for leaves of the current module and optionally all
@ -261,7 +265,8 @@ class Module(tracking.AutoTrackable):
predicate=predicate,
attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
attribute_traversal_key=attribute_traversal_key,
with_path=with_path)
with_path=with_path,
expand_composites=expand_composites)
@classmethod
def with_name_scope(cls, method):
@ -326,6 +331,7 @@ def _flatten_module(module,
attribute_traversal_key,
attributes_to_ignore,
with_path,
expand_composites,
module_path=(),
seen=None):
"""Implementation of `flatten`."""
@ -341,7 +347,8 @@ def _flatten_module(module,
prop = module_dict[key]
try:
leaves = nest.flatten_with_tuple_paths(prop)
leaves = nest.flatten_with_tuple_paths(
prop, expand_composites=expand_composites)
except Exception as cause: # pylint: disable=broad-except
six.raise_from(
ValueError(
@ -376,6 +383,7 @@ def _flatten_module(module,
attribute_traversal_key=attribute_traversal_key,
attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES, # pylint: disable=protected-access
with_path=with_path,
expand_composites=expand_composites,
module_path=submodule_path,
seen=seen)

View File

@ -31,8 +31,10 @@ from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as distributed_values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.module import module
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -260,6 +262,37 @@ class VariableTrackingTest(test_util.TensorFlowTestCase):
m.c = aggregating
self.assertEqual(m.variables, (mirrored, tpu, aggregating))
def test_composite_variable(self):
class Spec(type_spec.TypeSpec):
value_type = property(lambda self: CompositeVariable)
def _component_specs(self):
pass
def _serialize(self):
pass
def _to_components(self, value):
return value._variables
def _from_components(self, variable_list):
return CompositeVariable(variable_list)
class CompositeVariable(composite_tensor.CompositeTensor):
def __init__(self, variable_list):
self._variables = variable_list
@property
def _type_spec(self):
return Spec()
m = module.Module()
m.a = CompositeVariable([variables.Variable(1.), variables.Variable(2.)])
self.assertAllEqual(m.variables, m.a._variables)
class ModuleTrackingTest(test_util.TensorFlowTestCase):