Let CategoryEncoding error out for negative values and gives better error message.

PiperOrigin-RevId: 336751008
Change-Id: If7fb43127c2587b7658e8aed63331413ac932779
This commit is contained in:
Zhenyu Tan 2020-10-12 14:58:33 -07:00 committed by TensorFlower Gardener
parent 8bdf508eb1
commit 477cfa2aaa
2 changed files with 26 additions and 5 deletions

View File

@ -298,12 +298,18 @@ class CategoryEncoding(base_preprocessing_layer.CombinerPreprocessingLayer):
binary_output = (self._output_mode == BINARY)
if isinstance(inputs, sparse_tensor.SparseTensor):
max_value = math_ops.reduce_max(inputs.values)
min_value = math_ops.reduce_min(inputs.values)
else:
max_value = math_ops.reduce_max(inputs)
condition = math_ops.greater_equal(
math_ops.cast(out_depth, max_value.dtype), max_value)
min_value = math_ops.reduce_min(inputs)
condition = math_ops.logical_and(
math_ops.greater_equal(
math_ops.cast(out_depth, max_value.dtype), max_value),
math_ops.greater_equal(
min_value, math_ops.cast(0, min_value.dtype)))
control_flow_ops.Assert(
condition, ["Input must be less than max_token {}".format(out_depth)])
condition, ["Input values must be in the range 0 <= values < max_tokens"
" with max_tokens={}".format(out_depth)])
if self._sparse:
result = bincount_ops.sparse_bincount(
inputs,

View File

@ -277,8 +277,23 @@ class CategoryEncodingInputTest(keras_parameterized.TestCase,
int_data = encoder_layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
model = keras.Model(inputs=input_data, outputs=int_data)
with self.assertRaisesRegex(errors.InvalidArgumentError,
".*must be less than max_token 3"):
with self.assertRaisesRegex(
errors.InvalidArgumentError,
".*must be in the range 0 <= values < max_tokens.*"):
_ = model.predict(input_array, steps=1)
def test_dense_negative(self):
input_array = constant_op.constant([[1, 2, 0], [2, 2, -1]])
max_tokens = 3
expected_output_shape = [None, max_tokens]
encoder_layer = get_layer_class()(max_tokens)
input_data = keras.Input(shape=(3,), dtype=dtypes.int32)
int_data = encoder_layer(input_data)
self.assertAllEqual(expected_output_shape, int_data.shape.as_list())
model = keras.Model(inputs=input_data, outputs=int_data)
with self.assertRaisesRegex(
errors.InvalidArgumentError,
".*must be in the range 0 <= values < max_tokens.*"):
_ = model.predict(input_array, steps=1)