Set expand_composite to true when flattening nested structure of variables in keras variable assignment. It allows expanding structures that are a composite of variables.

PiperOrigin-RevId: 337230209
Change-Id: I88d3362837a02fad36afe5fc97f55d23b9cf36f5
This commit is contained in:
Chenkai Kuang 2020-10-14 20:43:31 -07:00 committed by TensorFlower Gardener
parent 8e9d65f93d
commit cd14fbf0cd
3 changed files with 48 additions and 1 deletions

View File

@ -666,6 +666,7 @@ tf_py_test(
":engine",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
@ -678,6 +679,7 @@ tf_py_test(
"//tensorflow/python:summary_ops_v2",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",

View File

@ -2831,7 +2831,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Append value to list of trainable / non-trainable weights if relevant
# TODO(b/125122625): This won't pick up on any variables added to a
# list/dict after creation.
for val in nest.flatten(value):
for val in nest.flatten(value, expand_composites=True):
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
# no longer return True for isinstance Variable checks.
if not isinstance(val, tf_variables.Variable):

View File

@ -27,12 +27,14 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend
from tensorflow.python.keras import combinations
from tensorflow.python.keras import keras_parameterized
@ -433,6 +435,49 @@ class BaseLayerTest(keras_parameterized.TestCase):
# Checks that variables get initialized.
model.fit(x, y, batch_size=2, epochs=2)
@combinations.generate(combinations.combine(mode=['eager']))
def test_composite_variable_assignment(self):
class Spec(type_spec.TypeSpec):
value_type = property(lambda self: CompositeVariable)
def _component_specs(self):
pass
def _serialize(self):
pass
def _to_components(self, value):
return value._variables
def _from_components(self, variable_list):
return CompositeVariable(variable_list)
class CompositeVariable(composite_tensor.CompositeTensor):
def __init__(self, variable_list):
self._variables = variable_list
@property
def _type_spec(self):
return Spec()
class CompositeVariableLayer(base_layer.Layer):
def __init__(self):
super().__init__()
self.composite_var = CompositeVariable(
[variables.Variable(1.),
variables.Variable(2.)])
layer = CompositeVariableLayer()
self.assertLen(layer.weights, 2)
self.assertIsInstance(layer.weights[0], variables.Variable)
self.assertIsInstance(layer.weights[1], variables.Variable)
self.assertEqual(self.evaluate(layer.weights[0]), 1.)
self.assertEqual(self.evaluate(layer.weights[1]), 2.)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_layer_names(self):
with testing_utils.use_keras_tensors_scope(False):