Let CategoryEncoding error out for negative values and gives better error message.
PiperOrigin-RevId: 336751008 Change-Id: If7fb43127c2587b7658e8aed63331413ac932779
This commit is contained in:
parent
8bdf508eb1
commit
477cfa2aaa
@ -298,12 +298,18 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
|
||||
binary_output = (self._output_mode == BINARY)
|
||||
if isinstance(inputs, sparse_tensor.SparseTensor):
|
||||
max_value = math_ops.reduce_max(inputs.values)
|
||||
min_value = math_ops.reduce_min(inputs.values)
|
||||
else:
|
||||
max_value = math_ops.reduce_max(inputs)
|
||||
condition = math_ops.greater_equal(
|
||||
math_ops.cast(out_depth, max_value.dtype), max_value)
|
||||
min_value = math_ops.reduce_min(inputs)
|
||||
condition = math_ops.logical_and(
|
||||
math_ops.greater_equal(
|
||||
math_ops.cast(out_depth, max_value.dtype), max_value),
|
||||
math_ops.greater_equal(
|
||||
min_value, math_ops.cast(0, min_value.dtype)))
|
||||
control_flow_ops.Assert(
|
||||
condition, ["Input must be less than max_token {}".format(out_depth)])
|
||||
condition, ["Input values must be in the range 0 <= values < max_tokens"
|
||||
" with max_tokens={}".format(out_depth)])
|
||||
if self._sparse:
|
||||
result = bincount_ops.sparse_bincount(
|
||||
inputs,
|
||||
|
||||
@ -277,8 +277,23 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
|
||||
int_data = encoder_layer(input_data)
|
||||
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
".*must be less than max_token 3"):
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
".*must be in the range 0 <= values < max_tokens.*"):
|
||||
_ = model.predict(input_array, steps=1)
|
||||
|
||||
def test_dense_negative(self):
|
||||
input_array = constant_op.constant([[1, 2, 0], [2, 2, -1]])
|
||||
max_tokens = 3
|
||||
expected_output_shape = [None, max_tokens]
|
||||
encoder_layer = get_layer_class()(max_tokens)
|
||||
input_data = keras.Input(shape=(3,), dtype=dtypes.int32)
|
||||
int_data = encoder_layer(input_data)
|
||||
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
|
||||
model = keras.Model(inputs=input_data, outputs=int_data)
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
".*must be in the range 0 <= values < max_tokens.*"):
|
||||
_ = model.predict(input_array, steps=1)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user