Add a clear error message when users attempt to pass multi-dimensional arrays to table_utils.
PiperOrigin-RevId: 313210934 Change-Id: Id45d84de3061efc9c1f17e6523512c4c41054e8b
This commit is contained in:
parent
db0a3952c0
commit
f684ae97cd
|
@ -21,6 +21,7 @@ import collections
|
|||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
@ -60,6 +61,11 @@ class TableHandler(object):
|
|||
raise RuntimeError("Size mismatch between values and key arrays. "
|
||||
"Keys had size %s, values had size %s." %
|
||||
(len(keys), len(values)))
|
||||
keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access
|
||||
values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access
|
||||
if values.shape.ndims != 1:
|
||||
raise ValueError("`values` must be 1-dimensional, got an input with "
|
||||
" %s dimensions." % values.shape.ndims)
|
||||
self._run(self.table.insert(keys, values))
|
||||
|
||||
def _replace_oov_buckets(self, inputs, lookups):
|
||||
|
|
|
@ -108,6 +108,15 @@ class CategoricalEncodingInputTest(
|
|||
|
||||
self.assertAllEqual(expected_output, output_data)
|
||||
|
||||
def test_tensor_multi_dim_values_fails(self):
|
||||
key_data = np.array([0, 1], dtype=np.int64)
|
||||
value_data = np.array([[11, 12], [21, 22]])
|
||||
|
||||
table = get_table(dtype=dtypes.int64, oov_tokens=[1, 2])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "must be 1-dimensional"):
|
||||
table.insert(key_data, value_data)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CategoricalEncodingMultiOOVTest(
|
||||
|
|
Loading…
Reference in New Issue