Let output prepserve input shape for IndexLookup

PiperOrigin-RevId: 294840562
Change-Id: Ib6d7ab94feb5a2eec13ab589a5c7e52282005425
This commit is contained in:
Zhenyu Tan 2020-02-12 22:50:43 -08:00 committed by TensorFlower Gardener
parent cdefd6af05
commit 17fd546b3f
2 changed files with 8 additions and 0 deletions

View File

@ -371,6 +371,8 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
inputs.dense_shape)
else:
indexed_data = table.lookup(inputs)
# (b/149446477): output does not preserve input shape.
indexed_data.set_shape(inputs.shape)
# Composite tensors can pass tensor values through, which will cause
# errors if this is the only layer in the model. To fix this, pass

View File

@ -322,6 +322,12 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
def test_output_shape(self):
input_data = keras.Input(shape=(4,), dtype=dtypes.string)
layer = get_layer_class()()
int_data = layer(input_data)
self.assertAllEqual(int_data.shape[1:], input_data.shape[1:])
def test_int_output_no_reserved_zero(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],