Add a clear error message when users attempt to pass multi-dimensional arrays to table_utils.
PiperOrigin-RevId: 313210934 Change-Id: Id45d84de3061efc9c1f17e6523512c4c41054e8b
This commit is contained in:
parent
db0a3952c0
commit
f684ae97cd
|
@ -21,6 +21,7 @@ import collections
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
@ -60,6 +61,11 @@ class TableHandler(object):
|
||||||
raise RuntimeError("Size mismatch between values and key arrays. "
|
raise RuntimeError("Size mismatch between values and key arrays. "
|
||||||
"Keys had size %s, values had size %s." %
|
"Keys had size %s, values had size %s." %
|
||||||
(len(keys), len(values)))
|
(len(keys), len(values)))
|
||||||
|
keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access
|
||||||
|
values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access
|
||||||
|
if values.shape.ndims != 1:
|
||||||
|
raise ValueError("`values` must be 1-dimensional, got an input with "
|
||||||
|
" %s dimensions." % values.shape.ndims)
|
||||||
self._run(self.table.insert(keys, values))
|
self._run(self.table.insert(keys, values))
|
||||||
|
|
||||||
def _replace_oov_buckets(self, inputs, lookups):
|
def _replace_oov_buckets(self, inputs, lookups):
|
||||||
|
|
|
@ -108,6 +108,15 @@ class CategoricalEncodingInputTest(
|
||||||
|
|
||||||
self.assertAllEqual(expected_output, output_data)
|
self.assertAllEqual(expected_output, output_data)
|
||||||
|
|
||||||
|
def test_tensor_multi_dim_values_fails(self):
|
||||||
|
key_data = np.array([0, 1], dtype=np.int64)
|
||||||
|
value_data = np.array([[11, 12], [21, 22]])
|
||||||
|
|
||||||
|
table = get_table(dtype=dtypes.int64, oov_tokens=[1, 2])
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, "must be 1-dimensional"):
|
||||||
|
table.insert(key_data, value_data)
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class CategoricalEncodingMultiOOVTest(
|
class CategoricalEncodingMultiOOVTest(
|
||||||
|
|
Loading…
Reference in New Issue