diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py index 16ac633f8dd..cf1bfd741c9 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py @@ -21,6 +21,7 @@ import collections import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend as K from tensorflow.python.ops import array_ops @@ -60,6 +61,11 @@ class TableHandler(object): raise RuntimeError("Size mismatch between values and key arrays. " "Keys had size %s, values had size %s." % (len(keys), len(values))) + keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access + values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access + if values.shape.ndims != 1: + raise ValueError("`values` must be 1-dimensional, got an input with " + " %s dimensions." % values.shape.ndims) self._run(self.table.insert(keys, values)) def _replace_oov_buckets(self, inputs, lookups): diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py index 60a891f6ba8..ab7e80b628c 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py @@ -108,6 +108,15 @@ class CategoricalEncodingInputTest( self.assertAllEqual(expected_output, output_data) + def test_tensor_multi_dim_values_fails(self): + key_data = np.array([0, 1], dtype=np.int64) + value_data = np.array([[11, 12], [21, 22]]) + + table = get_table(dtype=dtypes.int64, oov_tokens=[1, 2]) + + with self.assertRaisesRegexp(ValueError, "must be 1-dimensional"): + table.insert(key_data, value_data) + @keras_parameterized.run_all_keras_modes class CategoricalEncodingMultiOOVTest(