Set expand_composite
to true when flattening nested structure of variables in keras variable assignment. It allows expanding structures that are a composite of variables.
PiperOrigin-RevId: 337230209 Change-Id: I88d3362837a02fad36afe5fc97f55d23b9cf36f5
This commit is contained in:
parent
8e9d65f93d
commit
cd14fbf0cd
@ -666,6 +666,7 @@ tf_py_test(
|
||||
":engine",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -678,6 +679,7 @@ tf_py_test(
|
||||
"//tensorflow/python:summary_ops_v2",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/eager:context",
|
||||
|
@ -2831,7 +2831,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
# Append value to list of trainable / non-trainable weights if relevant
|
||||
# TODO(b/125122625): This won't pick up on any variables added to a
|
||||
# list/dict after creation.
|
||||
for val in nest.flatten(value):
|
||||
for val in nest.flatten(value, expand_composites=True):
|
||||
# TODO(b/126450014): Remove `_UnreadVariable` check here when assign ops
|
||||
# no longer return True for isinstance Variable checks.
|
||||
if not isinstance(val, tf_variables.Variable):
|
||||
|
@ -27,12 +27,14 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras import combinations
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
@ -433,6 +435,49 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||
# Checks that variables get initialized.
|
||||
model.fit(x, y, batch_size=2, epochs=2)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def test_composite_variable_assignment(self):
|
||||
|
||||
class Spec(type_spec.TypeSpec):
|
||||
|
||||
value_type = property(lambda self: CompositeVariable)
|
||||
|
||||
def _component_specs(self):
|
||||
pass
|
||||
|
||||
def _serialize(self):
|
||||
pass
|
||||
|
||||
def _to_components(self, value):
|
||||
return value._variables
|
||||
|
||||
def _from_components(self, variable_list):
|
||||
return CompositeVariable(variable_list)
|
||||
|
||||
class CompositeVariable(composite_tensor.CompositeTensor):
|
||||
|
||||
def __init__(self, variable_list):
|
||||
self._variables = variable_list
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return Spec()
|
||||
|
||||
class CompositeVariableLayer(base_layer.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.composite_var = CompositeVariable(
|
||||
[variables.Variable(1.),
|
||||
variables.Variable(2.)])
|
||||
|
||||
layer = CompositeVariableLayer()
|
||||
self.assertLen(layer.weights, 2)
|
||||
self.assertIsInstance(layer.weights[0], variables.Variable)
|
||||
self.assertIsInstance(layer.weights[1], variables.Variable)
|
||||
self.assertEqual(self.evaluate(layer.weights[0]), 1.)
|
||||
self.assertEqual(self.evaluate(layer.weights[1]), 2.)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
def test_layer_names(self):
|
||||
with testing_utils.use_keras_tensors_scope(False):
|
||||
|
Loading…
x
Reference in New Issue
Block a user