From be625d2fe824b1f7e12f128b3424f570fc5084d1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 18 Feb 2021 16:36:08 -0800 Subject: [PATCH] Allow file-based initializers with integer values. PiperOrigin-RevId: 358292431 Change-Id: Id27fc2dc0be23ef328503c73394bcbe1f0de59bc --- tensorflow/core/kernels/lookup_util.cc | 6 ++++-- tensorflow/python/ops/lookup_ops.py | 12 +++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index aa39063b2dc..6dcd93c5f5b 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -375,9 +375,11 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size, "Value index for line number requires table value dtype of int64, got ", DataTypeString(table->value_dtype())); } - if (value_index == kWholeLine && value_dtype != DT_STRING) { + if (value_index == kWholeLine && !DataTypeIsInteger(value_dtype) && + value_dtype != DT_STRING) { return errors::InvalidArgument( - "Value index for whole line requires table value dtype of string, got ", + "Value index for whole line requires table value dtype of integer or " + "string, got ", DataTypeString(table->value_dtype())); } diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index c5b9eea85fb..e541a00bd96 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -659,7 +659,7 @@ class TextFileInitializer(TableInitializerBase): - TextFileIndex.LINE_NUMBER means use the line number starting from zero, expects data type int64. - TextFileIndex.WHOLE_LINE means use the whole line content, expects data - type string. + type string or int64. - A value >=0 means use the index (starting at zero) of the split line based on `delimiter`. @@ -712,9 +712,11 @@ class TextFileInitializer(TableInitializerBase): if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64: raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % (dtypes.int64, value_dtype)) - if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string: - raise ValueError("Signature mismatch. Values must be dtype %s, got %s." % - (dtypes.string, value_dtype)) + if ((value_index == TextFileIndex.WHOLE_LINE) and + (not value_dtype.is_integer) and (value_dtype != dtypes.string)): + raise ValueError( + "Signature mismatch. Values must be integer or string, got %s." % + (value_dtype)) if (vocab_size is not None) and (vocab_size <= 0): raise ValueError("Invalid vocab_size %s." % vocab_size) @@ -795,7 +797,7 @@ class TextFileStringTableInitializer(TextFileInitializer): - TextFileIndex.LINE_NUMBER means use the line number starting from zero, expects data type int64. - TextFileIndex.WHOLE_LINE means use the whole line content, expects data - type string. + type string or int64. - A value >=0 means use the index (starting at zero) of the split line based on `delimiter`.