Make sure CategoryEncoding raises an error with input values equal or greater to max_tokens
.
PiperOrigin-RevId: 342972294 Change-Id: If8f186776d784f22487ea179c8e2f6d8d962bbf3
This commit is contained in:
parent
086faf6609
commit
39d3d60407
@ -304,7 +304,7 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
max_value = math_ops.reduce_max(inputs)
|
||||
min_value = math_ops.reduce_min(inputs)
|
||||
condition = math_ops.logical_and(
|
||||
math_ops.greater_equal(
|
||||
math_ops.greater(
|
||||
math_ops.cast(out_depth, max_value.dtype), max_value),
|
||||
math_ops.greater_equal(
|
||||
min_value, math_ops.cast(0, min_value.dtype)))
|
||||
|
@ -269,7 +269,7 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
|
||||
_ = model.predict(input_array, steps=1)
|
||||
|
||||
def test_dense_oov_input(self):
|
||||
input_array = constant_op.constant([[1, 2, 3], [4, 3, 4]])
|
||||
input_array = constant_op.constant([[0, 1, 2], [2, 3, 1]])
|
||||
max_tokens = 3
|
||||
expected_output_shape = [None, max_tokens]
|
||||
encoder_layer = get_layer_class()(max_tokens)
|
||||
|
Loading…
Reference in New Issue
Block a user