Allow file-based initializers with integer values.

PiperOrigin-RevId: 358292431
Change-Id: Id27fc2dc0be23ef328503c73394bcbe1f0de59bc
This commit is contained in:
A. Unique TensorFlower 2021-02-18 16:36:08 -08:00 committed by TensorFlower Gardener
parent c29e9f25e7
commit be625d2fe8
2 changed files with 11 additions and 7 deletions

View File

@ -375,9 +375,11 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
"Value index for line number requires table value dtype of int64, got ",
DataTypeString(table->value_dtype()));
}
if (value_index == kWholeLine && value_dtype != DT_STRING) {
if (value_index == kWholeLine && !DataTypeIsInteger(value_dtype) &&
value_dtype != DT_STRING) {
return errors::InvalidArgument(
"Value index for whole line requires table value dtype of string, got ",
"Value index for whole line requires table value dtype of integer or "
"string, got ",
DataTypeString(table->value_dtype()));
}

View File

@ -659,7 +659,7 @@ class TextFileInitializer(TableInitializerBase):
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
expects data type int64.
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
type string.
type string or int64.
- A value >=0 means use the index (starting at zero) of the split line based
on `delimiter`.
@ -712,9 +712,11 @@ class TextFileInitializer(TableInitializerBase):
if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
(dtypes.int64, value_dtype))
if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
(dtypes.string, value_dtype))
if ((value_index == TextFileIndex.WHOLE_LINE) and
(not value_dtype.is_integer) and (value_dtype != dtypes.string)):
raise ValueError(
"Signature mismatch. Values must be integer or string, got %s." %
(value_dtype))
if (vocab_size is not None) and (vocab_size <= 0):
raise ValueError("Invalid vocab_size %s." % vocab_size)
@ -795,7 +797,7 @@ class TextFileStringTableInitializer(TextFileInitializer):
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
expects data type int64.
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
type string.
type string or int64.
- A value >=0 means use the index (starting at zero) of the split line based
on `delimiter`.