Let output prepserve input shape for IndexLookup
PiperOrigin-RevId: 294840562 Change-Id: Ib6d7ab94feb5a2eec13ab589a5c7e52282005425
This commit is contained in:
parent
cdefd6af05
commit
17fd546b3f
@ -371,6 +371,8 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
inputs.dense_shape)
|
||||
else:
|
||||
indexed_data = table.lookup(inputs)
|
||||
# (b/149446477): output does not preserve input shape.
|
||||
indexed_data.set_shape(inputs.shape)
|
||||
|
||||
# Composite tensors can pass tensor values through, which will cause
|
||||
# errors if this is the only layer in the model. To fix this, pass
|
||||
|
@ -322,6 +322,12 @@ class IndexLookupOutputTest(keras_parameterized.TestCase,
|
||||
output_dataset = model.predict(input_array)
|
||||
self.assertAllEqual(expected_output, output_dataset)
|
||||
|
||||
def test_output_shape(self):
|
||||
input_data = keras.Input(shape=(4,), dtype=dtypes.string)
|
||||
layer = get_layer_class()()
|
||||
int_data = layer(input_data)
|
||||
self.assertAllEqual(int_data.shape[1:], input_data.shape[1:])
|
||||
|
||||
def test_int_output_no_reserved_zero(self):
|
||||
vocab_data = ["earth", "wind", "and", "fire"]
|
||||
input_array = np.array([["earth", "wind", "and", "fire"],
|
||||
|
Loading…
x
Reference in New Issue
Block a user