set _batch_input_shape
before build for kpl.
PiperOrigin-RevId: 317107061 Change-Id: Idd7f5dcbce9d15b2e3d708dd7ea3b2c3e5c1be7e
This commit is contained in:
parent
49e3e63dd7
commit
3cea671a74
@ -190,6 +190,8 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
|
||||
shape = data_element.shape
|
||||
except AttributeError:
|
||||
shape = None
|
||||
# TODO (b/159261555): move this to base layer build.
|
||||
self._batch_input_shape = shape
|
||||
self.build(shape)
|
||||
|
||||
# Once we have built the Layer, we can process the input data. We do so
|
||||
|
@ -318,6 +318,18 @@ class NormalizationTest(keras_parameterized.TestCase,
|
||||
layer.adapt(data)
|
||||
self.assertAllClose(expect, layer(data))
|
||||
|
||||
def test_model_summary_after_layer_adapt(self):
|
||||
data = np.array([[[0., 1., 2.], [0., 2., 6.]],
|
||||
[[2., 3., 4.], [3., 6., 10.]]])
|
||||
cls = get_layer_class()
|
||||
layer = cls(axis=-1)
|
||||
layer.adapt(data)
|
||||
model = keras.Sequential(
|
||||
[layer,
|
||||
keras.layers.Dense(64, activation="relu"),
|
||||
keras.layers.Dense(1)])
|
||||
model.summary()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user