set _batch_input_shape before build for kpl.

PiperOrigin-RevId: 317107061
Change-Id: Idd7f5dcbce9d15b2e3d708dd7ea3b2c3e5c1be7e
This commit is contained in:
Zhenyu Tan 2020-06-18 08:28:08 -07:00 committed by TensorFlower Gardener
parent 49e3e63dd7
commit 3cea671a74
2 changed files with 14 additions and 0 deletions

View File

@ -190,6 +190,8 @@ class CombinerPreprocessingLayer(PreprocessingLayer):
shape = data_element.shape
except AttributeError:
shape = None
# TODO (b/159261555): move this to base layer build.
self._batch_input_shape = shape
self.build(shape)
# Once we have built the Layer, we can process the input data. We do so

View File

@ -318,6 +318,18 @@ class NormalizationTest(keras_parameterized.TestCase,
layer.adapt(data)
self.assertAllClose(expect, layer(data))
def test_model_summary_after_layer_adapt(self):
data = np.array([[[0., 1., 2.], [0., 2., 6.]],
[[2., 3., 4.], [3., 6., 10.]]])
cls = get_layer_class()
layer = cls(axis=-1)
layer.adapt(data)
model = keras.Sequential(
[layer,
keras.layers.Dense(64, activation="relu"),
keras.layers.Dense(1)])
model.summary()
if __name__ == "__main__":
test.main()