From f684ae97cd895f2d150f6e41a9012c2f9a5a40e9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 26 May 2020 09:55:14 -0700 Subject: [PATCH] Add a clear error message when users attempt to pass multi-dimensional arrays to table_utils. PiperOrigin-RevId: 313210934 Change-Id: Id45d84de3061efc9c1f17e6523512c4c41054e8b --- .../python/keras/layers/preprocessing/table_utils.py | 6 ++++++ .../keras/layers/preprocessing/table_utils_test.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py index 16ac633f8dd..cf1bfd741c9 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py @@ -21,6 +21,7 @@ import collections import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import backend as K from tensorflow.python.ops import array_ops @@ -60,6 +61,11 @@ class TableHandler(object): raise RuntimeError("Size mismatch between values and key arrays. " "Keys had size %s, values had size %s." % (len(keys), len(values))) + keys = ops.convert_to_tensor(keys, dtype=self.table._key_dtype) # pylint: disable=protected-access + values = ops.convert_to_tensor(values, dtype=self.table._value_dtype) # pylint: disable=protected-access + if values.shape.ndims != 1: + raise ValueError("`values` must be 1-dimensional, got an input with " + " %s dimensions." % values.shape.ndims) self._run(self.table.insert(keys, values)) def _replace_oov_buckets(self, inputs, lookups): diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py index 60a891f6ba8..ab7e80b628c 100644 --- a/tensorflow/python/keras/layers/preprocessing/table_utils_test.py +++ b/tensorflow/python/keras/layers/preprocessing/table_utils_test.py @@ -108,6 +108,15 @@ class CategoricalEncodingInputTest( self.assertAllEqual(expected_output, output_data) + def test_tensor_multi_dim_values_fails(self): + key_data = np.array([0, 1], dtype=np.int64) + value_data = np.array([[11, 12], [21, 22]]) + + table = get_table(dtype=dtypes.int64, oov_tokens=[1, 2]) + + with self.assertRaisesRegexp(ValueError, "must be 1-dimensional"): + table.insert(key_data, value_data) + @keras_parameterized.run_all_keras_modes class CategoricalEncodingMultiOOVTest(