Add a clear error message when users attempt to pass multi-dimensional arrays to table_utils.

PiperOrigin-RevId: 313210934
Change-Id: Id45d84de3061efc9c1f17e6523512c4c41054e8b
This commit is contained in:
A. Unique TensorFlower 2020-05-26 09:55:14 -07:00 committed by TensorFlower Gardener
parent db0a3952c0
commit f684ae97cd
2 changed files with 15 additions and 0 deletions

View File

@ -21,6 +21,7 @@ import collections
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import array_ops
@ -60,6 +61,11 @@ class TableHandler(object):
raise RuntimeError("Size mismatch between values and key arrays. "
"Keys had size %s, values had size %s." %
(len(keys), len(values)))
keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access
values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access
if values.shape.ndims != 1:
raise ValueError("`values` must be 1-dimensional, got an input with "
" %s dimensions." % values.shape.ndims)
self._run(self.table.insert(keys, values))
def _replace_oov_buckets(self, inputs, lookups):

View File

@ -108,6 +108,15 @@ class CategoricalEncodingInputTest(
self.assertAllEqual(expected_output, output_data)
def test_tensor_multi_dim_values_fails(self):
key_data = np.array([0, 1], dtype=np.int64)
value_data = np.array([[11, 12], [21, 22]])
table = get_table(dtype=dtypes.int64, oov_tokens=[1, 2])
with self.assertRaisesRegexp(ValueError, "must be 1-dimensional"):
table.insert(key_data, value_data)
@keras_parameterized.run_all_keras_modes
class CategoricalEncodingMultiOOVTest(