Allowing tf.lookup tables to be initialized from a Dataset

PiperOrigin-RevId: 298464567
Change-Id: I8f5bbc175af1f8a6c19b0febb81ba30e6fd73b47
Rohan Jain 2020-03-02 15:38:38 -08:00 committed by TensorFlower Gardener
parent d17ab3f1d8
commit 74598809cd
15 changed files with 356 additions and 1 deletion
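In effect, this change introduces `tf.lookup.experimental.DatasetInitializer`, which builds a static lookup table from any `tf.data.Dataset` of `(key, value)` scalar pairs, backed by a new hidden `InitializeTableFromDataset` op. A minimal sketch of the new API, mirroring the docstring sample and the `test_basic` test below:

```python
import tensorflow as tf

# A dataset of (key, value) scalar pairs: 0 -> "0", 1 -> "2", 2 -> "4", ...
keys = tf.data.Dataset.range(100)
values = tf.data.Dataset.range(100).map(lambda x: tf.strings.as_string(x * 2))
ds = tf.data.Dataset.zip((keys, values))

init = tf.lookup.experimental.DatasetInitializer(ds)
table = tf.lookup.StaticHashTable(init, default_value="")
print(table.lookup(tf.constant([0, 2, 5], dtype=tf.int64)))  # [b"0" b"4" b"10"]
```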

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "InitializeTableFromDataset"
visibility: HIDDEN
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -163,4 +164,26 @@ REGISTER_KERNEL_BUILDER(
Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
InitializeTableFromTextFileOp);
class InitializeTableFromDatasetOp : public OpKernel {
public:
explicit InitializeTableFromDatasetOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
lookup::InitializableLookupTable* table;
OP_REQUIRES_OK(ctx,
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset));
OP_REQUIRES_OK(ctx,
lookup::InitializeTableFromDataset(ctx, dataset, table));
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromDatasetOp);
};
REGISTER_KERNEL_BUILDER(Name("InitializeTableFromDataset").Device(DEVICE_CPU),
InitializeTableFromDatasetOp);
} // namespace tensorflow
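The api_def above marks the op HIDDEN, so it has no top-level Python symbol of its own; it stays reachable through `tf.raw_ops` (see the golden API files at the end of this change). A sketch of exercising the kernel directly, bypassing `DatasetInitializer`; it assumes eager mode, where `HashTableV2` creates the table resource in place, and uses `tf.data.experimental.to_variant` to obtain the variant tensor the op expects:

```python
import tensorflow as tf

ds = tf.data.Dataset.zip(
    (tf.data.Dataset.range(3),
     tf.data.Dataset.range(3).map(tf.strings.as_string)))

# Create a bare hash-table resource, then populate it from the dataset.
handle = tf.raw_ops.HashTableV2(key_dtype=tf.int64, value_dtype=tf.string)
tf.raw_ops.InitializeTableFromDataset(
    table_handle=handle, dataset=tf.data.experimental.to_variant(ds))
print(tf.raw_ops.LookupTableFindV2(
    table_handle=handle,
    keys=tf.constant([1], tf.int64),
    default_value=tf.constant("", tf.string)))  # [b"1"]
```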

View File

@ -221,7 +221,9 @@ class HashTable : public InitializableLookupTable {
if (is_initialized()) {
return errors::Aborted("HashTable already initialized.");
}
// A size of 0 means the number of entries is unknown up front (e.g. a
// dataset with unknown cardinality), so only reserve when a real size is
// known.
if (size > 0) {
  table_.reserve(size);
}
return Status::OK();
};

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/kernels/lookup_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
@ -390,5 +392,105 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
return s;
}
class DatasetIterator : public InitializableLookupTable::InitTableIterator {
public:
explicit DatasetIterator(DatasetBase* dataset) : dataset_(dataset) {}
~DatasetIterator() override {}
Status Init(OpKernelContext* ctx) {
IteratorContext::Params params(ctx);
function_handle_cache_ =
absl::make_unique<data::FunctionHandleCache>(params.flr);
params.function_handle_cache = function_handle_cache_.get();
params.resource_mgr = &resource_mgr_;
cancellation_manager_ =
absl::make_unique<CancellationManager>(ctx->cancellation_manager());
params.cancellation_manager = cancellation_manager_.get();
iterator_ctx_ = absl::make_unique<IteratorContext>(std::move(params));
TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(), nullptr,
"LookupTable", &iterator_));
// Prime the iterator so keys()/values() are valid as soon as Init() returns.
Next();
return Status::OK();
}
void Next() override {
bool end_of_input;
tensors_.clear();
status_ = iterator_->GetNext(iterator_ctx_.get(), &tensors_, &end_of_input);
if (status_.ok() && end_of_input) {
status_ = errors::OutOfRange("end of iterator");
}
}
bool Valid() const override { return status_.ok(); }
const Tensor& keys() const override { return tensors_[0]; }
const Tensor& values() const override { return tensors_[1]; }
Status status() const override { return status_; }
int64 total_size() const override {
  int64 size = dataset_->Cardinality();
  if (size < 0) {
    // Cardinality is unknown (kUnknownCardinality) or infinite; report 0 so
    // callers skip pre-sizing the table.
    return 0;
  }
  return size;
}
private:
DatasetBase* dataset_; // not owned.
std::unique_ptr<IteratorContext> iterator_ctx_;
std::unique_ptr<data::FunctionHandleCache> function_handle_cache_;
ResourceMgr resource_mgr_;
std::unique_ptr<CancellationManager> cancellation_manager_;
std::unique_ptr<IteratorBase> iterator_;
std::vector<Tensor> tensors_;
Status status_;
};
Status InitializeTableFromDataset(OpKernelContext* ctx,
data::DatasetBase* dataset,
InitializableLookupTable* table) {
// Assert that the dataset output dtypes match those expected by the table.
const auto& dataset_types = dataset->output_dtypes();
if (dataset_types.size() != 2) {
return errors::InvalidArgument("Dataset should have two output types only");
}
if (dataset_types[0] != table->key_dtype()) {
return errors::InvalidArgument("Key dtype expected: ", table->key_dtype(),
" but obtained: ", dataset_types[0],
" from the dataset");
}
if (dataset_types[1] != table->value_dtype()) {
return errors::InvalidArgument(
"Value dtype expected: ", table->value_dtype(),
" but obtained: ", dataset_types[1], " from the dataset");
}
// Assert that the dataset output shapes are scalars.
const auto& dataset_shapes = dataset->output_shapes();
if (dataset_shapes.size() != 2) {
return errors::InvalidArgument(
"Dataset should have two output shapes only");
}
if (!dataset_shapes[0].IsCompatibleWith(PartialTensorShape({}))) {
return errors::InvalidArgument("Expected scalar for key. Obtained: ",
dataset_shapes[0].DebugString());
}
if (!dataset_shapes[1].IsCompatibleWith(PartialTensorShape({}))) {
return errors::InvalidArgument("Expected scalar for key. Obtained: ",
dataset_shapes[1].DebugString());
}
DatasetIterator iter(dataset);
TF_RETURN_IF_ERROR(iter.Init(ctx));
Status s = table->Initialize(iter);
if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
LOG(INFO) << "Table already initialized from dataset.";
return Status::OK();
}
return s;
}
} // namespace lookup
} // namespace tensorflow
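The dtype checks above are hard to trip through the public wrapper, because `StaticHashTable` derives its key and value dtypes from the initializer itself; going through `tf.raw_ops` shows them firing. A small sketch, reusing the raw-op pattern from earlier:

```python
import tensorflow as tf

ds = tf.data.Dataset.zip(
    (tf.data.Dataset.range(3),  # int64 keys
     tf.data.Dataset.range(3).map(tf.strings.as_string)))

# Declare string keys on the table while the dataset yields int64 keys.
handle = tf.raw_ops.HashTableV2(key_dtype=tf.string, value_dtype=tf.string)
try:
  tf.raw_ops.InitializeTableFromDataset(
      table_handle=handle, dataset=tf.data.experimental.to_variant(ds))
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # Key dtype expected: string but obtained: int64 ...
```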

View File

@ -20,6 +20,12 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"
namespace tensorflow {
namespace data {
class DatasetBase;
} // namespace data
} // namespace tensorflow
namespace tensorflow {
namespace lookup {
@ -50,6 +56,12 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
int32 value_index, Env* env,
InitializableLookupTable* table);
// Initializes `table` from `dataset` by iterating over it. Caller retains
// ownership of `dataset`.
Status InitializeTableFromDataset(OpKernelContext* ctx,
data::DatasetBase* dataset,
InitializableLookupTable* table);
} // namespace lookup
} // namespace tensorflow

View File

@ -505,4 +505,14 @@ REGISTER_OP("InitializeTableFromTextFileV2")
return Status::OK();
});
REGISTER_OP("InitializeTableFromDataset")
.Input("table_handle: resource")
.Input("dataset: variant")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
return Status::OK();
});
} // namespace tensorflow

View File

@ -651,6 +651,8 @@ tf_py_test(
"//tensorflow/python:lookup_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
],
)

View File

@ -27,6 +27,7 @@ from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as reader_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@ -42,6 +43,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@ -488,6 +490,82 @@ class KeyValueTensorInitializerTest(BaseLookupTableTest):
self.initialize_table(table)
class DatasetInitializerTest(BaseLookupTableTest):
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
with open(vocabulary_file, "w") as f:
f.write("\n".join(values) + "\n")
return vocabulary_file
def test_basic(self):
keys = dataset_ops.Dataset.range(100)
values = dataset_ops.Dataset.range(100).map(
lambda x: string_ops.as_string(x * 2))
ds = dataset_ops.Dataset.zip((keys, values))
init = lookup_ops.DatasetInitializer(ds)
table = self.getHashTable()(init, default_value="")
self.initialize_table(table)
output = table.lookup(constant_op.constant([0, 2, 5], dtypes.int64))
result = self.evaluate(output)
self.assertAllEqual(["0", "4", "10"], result)
def test_basic_bad_shape(self):
keys = dataset_ops.Dataset.range(100)
values = dataset_ops.Dataset.range(100).map(
lambda x: string_ops.as_string(x * 2))
values = values.batch(4)
ds = dataset_ops.Dataset.zip((keys, values))
with self.assertRaises(ValueError):
lookup_ops.DatasetInitializer(ds)
def test_from_file(self):
vocabulary_file = self._createVocabFile("test.txt", ("one", "two", "three"))
ds = reader_ops.TextLineDataset(vocabulary_file)
ds = ds.enumerate(start=1)
init = lookup_ops.DatasetInitializer(ds)
table = self.getHashTable()(init, default_value="")
self.initialize_table(table)
output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
result = self.evaluate(output)
self.assertAllEqual(["two", "three", ""], result)
def test_from_multiple_files(self):
vocabulary_file1 = self._createVocabFile("test1.txt",
("one", "two", "three"))
vocabulary_file2 = self._createVocabFile("test2.txt",
("four", "five", "six"))
ds = reader_ops.TextLineDataset([vocabulary_file1, vocabulary_file2])
ds = ds.enumerate(start=1)
init = lookup_ops.DatasetInitializer(ds)
table = self.getHashTable()(init, default_value="")
self.initialize_table(table)
output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
result = self.evaluate(output)
self.assertAllEqual(["two", "three", "four"], result)
def test_map_variable(self):
ds = dataset_ops.Dataset.range(100)
captured_var = variables.Variable(0)
def func(_):
return captured_var.assign_add(1)
ds = ds.map(func)
ds = ds.enumerate(start=1)
init = lookup_ops.DatasetInitializer(ds)
table = self.getHashTable()(init, default_value=-1)
self.evaluate(captured_var.initializer)
self.initialize_table(table)
output = table.lookup(constant_op.constant([1, 2, 101], dtypes.int64))
result = self.evaluate(output)
self.assertAllEqual([1, 2, -1], result)
class InitializeTableFromFileOpTest(BaseLookupTableTest):
def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):

View File

@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -411,6 +412,65 @@ class TableInitializerBase(trackable_base.Trackable):
return shared_name
@tf_export("lookup.experimental.DatasetInitializer")
class DatasetInitializer(TableInitializerBase):
"""Creates a table initializer from a `tf.data.Dataset`.
Sample usage:
```python
keys = tf.data.Dataset.range(100)
values = tf.data.Dataset.range(100).map(
lambda x: string_ops.as_string(x * 2))
ds = tf.data.Dataset.zip((keys, values))
init = tf.lookup.experimental.DatasetInitializer(ds)
table = tf.lookup.StaticHashTable(init, "")
output = table.lookup([0, 1, 2])
assertEquals(outputs, ["0", "2", "4"])
```
Attributes:
dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
first scalar is treated as a key and the second as value.
Raises: ValueError if `dataset` doesn't conform to specifications.
"""
def __init__(self, dataset):
"""Creates a table initializser from a `tf.data.Dataset`.
Args:
dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
first scalar is treated as a key and the second as value.
Raises: ValueError if `dataset` doesn't conform to specifications.
Returns: A `DatasetInitializer` object
"""
# Assert that the dataset element spec is a tuple of TensorSpecs where
# each tensor is a scalar.
self.dataset = dataset
elem_spec = self.dataset.element_spec
if len(elem_spec) != 2:
raise ValueError("element spec size should be 2")
if not isinstance(elem_spec[0], tensor_spec.TensorSpec):
raise ValueError("elem_spec[0] should be of type TensorSpec")
if not isinstance(elem_spec[1], tensor_spec.TensorSpec):
raise ValueError("elem_spec[1] should be of type TensorSpec")
if elem_spec[0].shape.rank not in (None, 0):
raise ValueError("key tensor should be a scalar")
if elem_spec[1].shape.rank not in (None, 0):
raise ValueError("value tensor should be a scalar")
key_type = elem_spec[0].dtype
value_type = elem_spec[1].dtype
super(DatasetInitializer, self).__init__(key_type, value_type)
def initialize(self, table):
_check_table_dtypes(table, self._key_dtype, self._value_dtype)
init_op = gen_lookup_ops.initialize_table_from_dataset(
table.resource_handle, self.dataset._variant_tensor) # pylint: disable=protected-access
return init_op
@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
"""Table initializers given `keys` and `values` tensors."""

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup.experimental.DatasetInitializer"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DatasetInitializer\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "initialize"
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.lookup.experimental"
tf_module {
member {
name: "DatasetInitializer"
mtype: "<type \'type\'>"
}
member {
name: "DenseHashTable"
mtype: "<type \'type\'>"

View File

@ -1816,6 +1816,10 @@ tf_module {
name: "InitializeTable"
argspec: "args=[\'table_handle\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "InitializeTableFromDataset"
argspec: "args=[\'table_handle\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "InitializeTableFromTextFile"
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "

View File

@ -0,0 +1,23 @@
path: "tensorflow.lookup.experimental.DatasetInitializer"
tf_class {
is_instance: "<class \'tensorflow.python.ops.lookup_ops.DatasetInitializer\'>"
is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "key_dtype"
mtype: "<type \'property\'>"
}
member {
name: "value_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "initialize"
argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.lookup.experimental"
tf_module {
member {
name: "DatasetInitializer"
mtype: "<type \'type\'>"
}
member {
name: "DenseHashTable"
mtype: "<type \'type\'>"

View File

@ -1816,6 +1816,10 @@ tf_module {
name: "InitializeTable"
argspec: "args=[\'table_handle\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "InitializeTableFromDataset"
argspec: "args=[\'table_handle\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "InitializeTableFromTextFile"
argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "