Allow file-based initializers with integer values.
PiperOrigin-RevId: 358292431 Change-Id: Id27fc2dc0be23ef328503c73394bcbe1f0de59bc
This commit is contained in:
parent
c29e9f25e7
commit
be625d2fe8
tensorflow
@ -375,9 +375,11 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
|
|||||||
"Value index for line number requires table value dtype of int64, got ",
|
"Value index for line number requires table value dtype of int64, got ",
|
||||||
DataTypeString(table->value_dtype()));
|
DataTypeString(table->value_dtype()));
|
||||||
}
|
}
|
||||||
if (value_index == kWholeLine && value_dtype != DT_STRING) {
|
if (value_index == kWholeLine && !DataTypeIsInteger(value_dtype) &&
|
||||||
|
value_dtype != DT_STRING) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Value index for whole line requires table value dtype of string, got ",
|
"Value index for whole line requires table value dtype of integer or "
|
||||||
|
"string, got ",
|
||||||
DataTypeString(table->value_dtype()));
|
DataTypeString(table->value_dtype()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -659,7 +659,7 @@ class TextFileInitializer(TableInitializerBase):
|
|||||||
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
|
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
|
||||||
expects data type int64.
|
expects data type int64.
|
||||||
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
|
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
|
||||||
type string.
|
type string or int64.
|
||||||
- A value >=0 means use the index (starting at zero) of the split line based
|
- A value >=0 means use the index (starting at zero) of the split line based
|
||||||
on `delimiter`.
|
on `delimiter`.
|
||||||
|
|
||||||
@ -712,9 +712,11 @@ class TextFileInitializer(TableInitializerBase):
|
|||||||
if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
|
if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
|
||||||
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
|
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
|
||||||
(dtypes.int64, value_dtype))
|
(dtypes.int64, value_dtype))
|
||||||
if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
|
if ((value_index == TextFileIndex.WHOLE_LINE) and
|
||||||
raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
|
(not value_dtype.is_integer) and (value_dtype != dtypes.string)):
|
||||||
(dtypes.string, value_dtype))
|
raise ValueError(
|
||||||
|
"Signature mismatch. Values must be integer or string, got %s." %
|
||||||
|
(value_dtype))
|
||||||
|
|
||||||
if (vocab_size is not None) and (vocab_size <= 0):
|
if (vocab_size is not None) and (vocab_size <= 0):
|
||||||
raise ValueError("Invalid vocab_size %s." % vocab_size)
|
raise ValueError("Invalid vocab_size %s." % vocab_size)
|
||||||
@ -795,7 +797,7 @@ class TextFileStringTableInitializer(TextFileInitializer):
|
|||||||
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
|
- TextFileIndex.LINE_NUMBER means use the line number starting from zero,
|
||||||
expects data type int64.
|
expects data type int64.
|
||||||
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
|
- TextFileIndex.WHOLE_LINE means use the whole line content, expects data
|
||||||
type string.
|
type string or int64.
|
||||||
- A value >=0 means use the index (starting at zero) of the split line based
|
- A value >=0 means use the index (starting at zero) of the split line based
|
||||||
on `delimiter`.
|
on `delimiter`.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user