Allowing tf.lookup tables to be initialized from a Dataset
PiperOrigin-RevId: 298464567 Change-Id: I8f5bbc175af1f8a6c19b0febb81ba30e6fd73b47
parent d17ab3f1d8
commit 74598809cd
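In brief: this change adds a hidden `InitializeTableFromDataset` op with a CPU kernel, the `lookup::InitializeTableFromDataset` plumbing, and a public `tf.lookup.experimental.DatasetInitializer` wrapper. A minimal end-to-end sketch of the resulting API, adapted from the docstring and tests in this diff (assumes a build containing this commit):

```python
import tensorflow as tf

keys = tf.data.Dataset.range(100)
values = tf.data.Dataset.range(100).map(lambda x: tf.strings.as_string(x * 2))
ds = tf.data.Dataset.zip((keys, values))  # yields (int64 key, string value)

init = tf.lookup.experimental.DatasetInitializer(ds)
table = tf.lookup.StaticHashTable(init, default_value="")
print(table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)))
# tf.Tensor([b'0' b'2' b'4'], shape=(3,), dtype=string)
```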
@@ -0,0 +1,4 @@
op {
  graph_op_name: "InitializeTableFromDataset"
  visibility: HIDDEN
}
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -163,4 +164,26 @@ REGISTER_KERNEL_BUILDER(
    Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
    InitializeTableFromTextFileOp);

class InitializeTableFromDatasetOp : public OpKernel {
 public:
  explicit InitializeTableFromDatasetOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    lookup::InitializableLookupTable* table;
    OP_REQUIRES_OK(ctx,
                   GetInitializableLookupTable("table_handle", ctx, &table));
    core::ScopedUnref unref_me(table);
    DatasetBase* dataset;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset));
    OP_REQUIRES_OK(ctx,
                   lookup::InitializeTableFromDataset(ctx, dataset, table));
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(InitializeTableFromDatasetOp);
};

REGISTER_KERNEL_BUILDER(Name("InitializeTableFromDataset").Device(DEVICE_CPU),
                        InitializeTableFromDatasetOp);
}  // namespace tensorflow
@@ -221,7 +221,9 @@ class HashTable : public InitializableLookupTable {
    if (is_initialized()) {
      return errors::Aborted("HashTable already initialized.");
    }
    if (size > 0) {
      table_.reserve(size);
    }
    return Status::OK();
  };
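The new `size > 0` guard pairs with `DatasetIterator::total_size()` below, which reports 0 whenever `Cardinality()` is negative (unknown or infinite), so `reserve` only runs with a genuine size hint. A sketch of a dataset whose cardinality is statically unknown (assuming the `tf.data.experimental.cardinality` helper):

```python
import tensorflow as tf

# A filtered dataset cannot be counted statically, so its cardinality is
# UNKNOWN (-2) and the hash table receives no reserve() hint.
ds = tf.data.Dataset.range(100).filter(lambda x: x % 2 == 0)
print(tf.data.experimental.cardinality(ds).numpy())  # -2
```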
@@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/core/kernels/lookup_util.h"

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -390,5 +392,105 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
  return s;
}

class DatasetIterator : public InitializableLookupTable::InitTableIterator {
 public:
  explicit DatasetIterator(DatasetBase* dataset) : dataset_(dataset) {}

  ~DatasetIterator() override {}

  Status Init(OpKernelContext* ctx) {
    IteratorContext::Params params(ctx);
    function_handle_cache_ =
        absl::make_unique<data::FunctionHandleCache>(params.flr);
    params.function_handle_cache = function_handle_cache_.get();
    params.resource_mgr = &resource_mgr_;
    cancellation_manager_ =
        absl::make_unique<CancellationManager>(ctx->cancellation_manager());
    params.cancellation_manager = cancellation_manager_.get();
    iterator_ctx_ = absl::make_unique<IteratorContext>(std::move(params));
    TF_RETURN_IF_ERROR(dataset_->MakeIterator(iterator_ctx_.get(), nullptr,
                                              "LookupTable", &iterator_));
    Next();
    return Status::OK();
  }

  void Next() override {
    bool end_of_input;
    tensors_.clear();
    status_ = iterator_->GetNext(iterator_ctx_.get(), &tensors_, &end_of_input);
    if (status_.ok() && end_of_input) {
      status_ = errors::OutOfRange("end of iterator");
    }
  }

  bool Valid() const override { return status_.ok(); }

  const Tensor& keys() const override { return tensors_[0]; }

  const Tensor& values() const override { return tensors_[1]; }

  Status status() const override { return status_; }

  int64 total_size() const override {
    int64 size = dataset_->Cardinality();
    if (size < 0) {
      return 0;
    }
    return size;
  }

 private:
  DatasetBase* dataset_;  // not owned.
  std::unique_ptr<IteratorContext> iterator_ctx_;
  std::unique_ptr<data::FunctionHandleCache> function_handle_cache_;
  ResourceMgr resource_mgr_;
  std::unique_ptr<CancellationManager> cancellation_manager_;
  std::unique_ptr<IteratorBase> iterator_;
  std::vector<Tensor> tensors_;
  Status status_;
};

Status InitializeTableFromDataset(OpKernelContext* ctx,
                                  data::DatasetBase* dataset,
                                  InitializableLookupTable* table) {
  // Assert that the dataset's types match those expected by the table.
  const auto& dataset_types = dataset->output_dtypes();
  if (dataset_types.size() != 2) {
    return errors::InvalidArgument("Dataset should have two output types only");
  }
  if (dataset_types[0] != table->key_dtype()) {
    return errors::InvalidArgument("Key dtype expected: ", table->key_dtype(),
                                   " but obtained: ", dataset_types[0],
                                   " from the dataset");
  }
  if (dataset_types[1] != table->value_dtype()) {
    return errors::InvalidArgument(
        "Value dtype expected: ", table->value_dtype(),
        " but obtained: ", dataset_types[1], " from the dataset");
  }
  // Assert that the dataset output shapes are scalars.
  const auto& dataset_shapes = dataset->output_shapes();
  if (dataset_shapes.size() != 2) {
    return errors::InvalidArgument(
        "Dataset should have two output shapes only");
  }
  if (!dataset_shapes[0].IsCompatibleWith(PartialTensorShape({}))) {
    return errors::InvalidArgument("Expected scalar for key. Obtained: ",
                                   dataset_shapes[0].DebugString());
  }
  if (!dataset_shapes[1].IsCompatibleWith(PartialTensorShape({}))) {
    return errors::InvalidArgument("Expected scalar for value. Obtained: ",
                                   dataset_shapes[1].DebugString());
  }
  DatasetIterator iter(dataset);
  TF_RETURN_IF_ERROR(iter.Init(ctx));
  Status s = table->Initialize(iter);
  if (errors::IsFailedPrecondition(s) && table->is_initialized()) {
    LOG(INFO) << "Table already initialized from dataset.";
    return Status::OK();
  }
  return s;
}

}  // namespace lookup
}  // namespace tensorflow
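These checks duplicate the Python-side validation that `DatasetInitializer.__init__` performs below, but they also cover direct users of the raw op. A hedged sketch of the key-dtype check firing (assumes the op is reachable as `tf.raw_ops.InitializeTableFromDataset`, consistent with the golden-API updates at the end of this diff):

```python
import tensorflow as tf

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant([0], dtype=tf.int64), tf.constant(["zero"])),
    default_value="")
# The dataset yields (string, string) pairs, but the table expects int64 keys.
ds = tf.data.Dataset.from_tensor_slices((["a"], ["b"]))
try:
  tf.raw_ops.InitializeTableFromDataset(
      table_handle=table.resource_handle,
      dataset=ds._variant_tensor)  # pylint: disable=protected-access
except tf.errors.InvalidArgumentError as e:
  print(e.message)  # "Key dtype expected: int64 but obtained: string ..."
```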
@@ -20,6 +20,12 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/initializable_lookup_table.h"

namespace tensorflow {
namespace data {
class DatasetBase;
}  // namespace data
}  // namespace tensorflow

namespace tensorflow {
namespace lookup {
@@ -50,6 +56,12 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
                                   int32 value_index, Env* env,
                                   InitializableLookupTable* table);

// Initializes `table` from `dataset` by iterating over it. Caller retains
// ownership of `dataset`.
Status InitializeTableFromDataset(OpKernelContext* ctx,
                                  data::DatasetBase* dataset,
                                  InitializableLookupTable* table);

}  // namespace lookup
}  // namespace tensorflow
@@ -505,4 +505,14 @@ REGISTER_OP("InitializeTableFromTextFileV2")
      return Status::OK();
    });

REGISTER_OP("InitializeTableFromDataset")
    .Input("table_handle: resource")
    .Input("dataset: variant")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle handle;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
      return Status::OK();
    });

}  // namespace tensorflow
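The generated wrapper for this registration, `gen_lookup_ops.initialize_table_from_dataset`, is what `DatasetInitializer.initialize` calls later in this diff, so the op also slots into the usual v1 table-initialization flow. A sketch under graph execution (assumption: the initializer registers with `tf.tables_initializer()` like the existing table initializers):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

keys = tf.data.Dataset.range(3)
values = tf.data.Dataset.range(3).map(tf.strings.as_string)
init = tf.lookup.experimental.DatasetInitializer(
    tf.data.Dataset.zip((keys, values)))
table = tf.lookup.StaticHashTable(init, default_value="")
out = table.lookup(tf.constant([0, 1, 2], dtype=tf.int64))

with tf.Session() as sess:
  sess.run(tf.tables_initializer())  # runs InitializeTableFromDataset
  print(sess.run(out))  # [b'0' b'1' b'2']
```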
@@ -651,6 +651,8 @@ tf_py_test(
        "//tensorflow/python:lookup_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:training",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:readers",
    ],
)
@@ -27,6 +27,7 @@ from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as reader_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@@ -42,6 +43,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@@ -488,6 +490,82 @@ class KeyValueTensorInitializerTest(BaseLookupTableTest):
    self.initialize_table(table)


class DatasetInitializerTest(BaseLookupTableTest):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def test_basic(self):
    keys = dataset_ops.Dataset.range(100)
    values = dataset_ops.Dataset.range(100).map(
        lambda x: string_ops.as_string(x * 2))
    ds = dataset_ops.Dataset.zip((keys, values))
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value="")
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([0, 2, 5], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual(["0", "4", "10"], result)

  def test_basic_bad_shape(self):
    keys = dataset_ops.Dataset.range(100)
    values = dataset_ops.Dataset.range(100).map(
        lambda x: string_ops.as_string(x * 2))
    values = values.batch(4)
    ds = dataset_ops.Dataset.zip((keys, values))
    with self.assertRaises(ValueError):
      lookup_ops.DatasetInitializer(ds)

  def test_from_file(self):
    vocabulary_file = self._createVocabFile("test.txt", ("one", "two", "three"))
    ds = reader_ops.TextLineDataset(vocabulary_file)
    ds = ds.enumerate(start=1)
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value="")
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual(["two", "three", ""], result)

  def test_from_multiple_files(self):
    vocabulary_file1 = self._createVocabFile("test1.txt",
                                             ("one", "two", "three"))
    vocabulary_file2 = self._createVocabFile("test2.txt",
                                             ("four", "five", "six"))
    ds = reader_ops.TextLineDataset([vocabulary_file1, vocabulary_file2])
    ds = ds.enumerate(start=1)
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value="")
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([2, 3, 4], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual(["two", "three", "four"], result)

  def test_map_variable(self):
    ds = dataset_ops.Dataset.range(100)
    captured_var = variables.Variable(0)

    def func(_):
      return captured_var.assign_add(1)

    ds = ds.map(func)
    ds = ds.enumerate(start=1)
    init = lookup_ops.DatasetInitializer(ds)
    table = self.getHashTable()(init, default_value=-1)
    self.evaluate(captured_var.initializer)
    self.initialize_table(table)

    output = table.lookup(constant_op.constant([1, 2, 101], dtypes.int64))
    result = self.evaluate(output)
    self.assertAllEqual([1, 2, -1], result)


class InitializeTableFromFileOpTest(BaseLookupTableTest):

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
@@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -411,6 +412,65 @@ class TableInitializerBase(trackable_base.Trackable):
    return shared_name


@tf_export("lookup.experimental.DatasetInitializer")
class DatasetInitializer(TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:
  ```python
  keys = tf.data.Dataset.range(100)
  values = tf.data.Dataset.range(100).map(
      lambda x: string_ops.as_string(x * 2))
  ds = tf.data.Dataset.zip((keys, values))
  init = tf.lookup.experimental.DatasetInitializer(ds)
  table = tf.lookup.StaticHashTable(init, "")
  output = table.lookup([0, 1, 2])
  assertEquals(output, ["0", "2", "4"])
  ```

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as a key and the second as value.

  Raises: ValueError if `dataset` doesn't conform to specifications.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as a key and the second as value.

    Raises: ValueError if `dataset` doesn't conform to specifications.
    Returns: A `DatasetInitializer` object.
    """
    # Assert that the dataset element spec is a tuple of TensorSpecs where
    # each tensor is a scalar.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    if len(elem_spec) != 2:
      raise ValueError("element spec size should be 2")
    if not isinstance(elem_spec[0], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[0] should be of type TensorSpec")
    if not isinstance(elem_spec[1], tensor_spec.TensorSpec):
      raise ValueError("elem_spec[1] should be of type TensorSpec")
    if elem_spec[0].shape.rank not in (None, 0):
      raise ValueError("key tensor should be a scalar")
    if elem_spec[1].shape.rank not in (None, 0):
      raise ValueError("value tensor should be a scalar")

    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    _check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = gen_lookup_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    return init_op


@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializers given `keys` and `values` tensors."""
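A common application of `DatasetInitializer` is building a string-to-index vocabulary table directly from a text file, as the `test_from_file` case above does with `enumerate`. A sketch in public-API terms (the file name `vocab.txt` is hypothetical):

```python
import tensorflow as tf

# One token per line, e.g. "brain\nsalad\nsurgery\n" (hypothetical file).
ds = tf.data.TextLineDataset("vocab.txt")
# enumerate() yields (index, token); swap the pair so the table maps each
# token to its line number.
ds = ds.enumerate().map(lambda i, token: (token, i))
init = tf.lookup.experimental.DatasetInitializer(ds)
table = tf.lookup.StaticHashTable(init, default_value=-1)
print(table.lookup(tf.constant(["salad", "unknown"])))  # [1, -1]
```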
@@ -0,0 +1,23 @@
path: "tensorflow.lookup.experimental.DatasetInitializer"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.lookup_ops.DatasetInitializer\'>"
  is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "key_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "value_dtype"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "initialize"
    argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -1,5 +1,9 @@
path: "tensorflow.lookup.experimental"
tf_module {
  member {
    name: "DatasetInitializer"
    mtype: "<type \'type\'>"
  }
  member {
    name: "DenseHashTable"
    mtype: "<type \'type\'>"
@@ -1816,6 +1816,10 @@ tf_module {
    name: "InitializeTable"
    argspec: "args=[\'table_handle\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "InitializeTableFromDataset"
    argspec: "args=[\'table_handle\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "InitializeTableFromTextFile"
    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "
@@ -0,0 +1,23 @@
path: "tensorflow.lookup.experimental.DatasetInitializer"
tf_class {
  is_instance: "<class \'tensorflow.python.ops.lookup_ops.DatasetInitializer\'>"
  is_instance: "<class \'tensorflow.python.ops.lookup_ops.TableInitializerBase\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "key_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "value_dtype"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "initialize"
    argspec: "args=[\'self\', \'table\'], varargs=None, keywords=None, defaults=None"
  }
}
@@ -1,5 +1,9 @@
path: "tensorflow.lookup.experimental"
tf_module {
  member {
    name: "DatasetInitializer"
    mtype: "<type \'type\'>"
  }
  member {
    name: "DenseHashTable"
    mtype: "<type \'type\'>"
@@ -1816,6 +1816,10 @@ tf_module {
    name: "InitializeTable"
    argspec: "args=[\'table_handle\', \'keys\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "InitializeTableFromDataset"
    argspec: "args=[\'table_handle\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "InitializeTableFromTextFile"
    argspec: "args=[\'table_handle\', \'filename\', \'key_index\', \'value_index\', \'vocab_size\', \'delimiter\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'\\t\', \'None\'], "