Add a mechanism for switching between multiple iterators by feeding a handle.

With this change, you can do the following:

1. Fetch a string handle for any iterator, by evaluating the result of
   `Iterator.string_handle()`.
2. Define an `Iterator` object based on a `tf.string` placeholder handle.
3. Feed the placeholder using an evaluated string handle to use a particular
   iterator in a particular step.

Concretely, this allows you to define two iterators for a training dataset and
a test dataset, and choose which one to use on a per-run basis:

```python
train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
train_iterator_handle = sess.run(train_iterator.string_handle())

test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
test_iterator_handle = sess.run(test_iterator.string_handle())

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, train_iterator.output_types)

next_element = iterator.get_next()
loss = f(next_element)

train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
```
PiperOrigin-RevId: 161719836
This commit is contained in:
Derek Murray 2017-07-12 14:39:56 -07:00 committed by TensorFlower Gardener
parent 6d6dda807c
commit 71c4ec8ed6
4 changed files with 205 additions and 0 deletions

View File

@ -328,6 +328,54 @@ class IteratorTest(test.TestCase):
[1, 2, 3], dtype=dtypes.int64), constant_op.constant(
[4., 5., 6., 7.], dtype=dtypes.float64))))
def testIteratorStringHandle(self):
  """Feeds string handles to switch a feedable iterator between datasets."""
  dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
  dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
  iterator_3 = dataset_3.make_one_shot_iterator()
  iterator_4 = dataset_4.make_one_shot_iterator()

  handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
  feedable_iterator = dataset_ops.Iterator.from_string_handle(
      handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
  next_element = feedable_iterator.get_next()

  # Both datasets share the same element structure, which the feedable
  # iterator reports.
  self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
  self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
  self.assertEqual([], feedable_iterator.output_shapes)

  with self.test_session() as sess:
    iterator_3_handle = sess.run(iterator_3.string_handle())
    iterator_4_handle = sess.run(iterator_4.string_handle())

    # Alternate between the two iterators; each one keeps its own position
    # independently of the other.
    expected_sequence = [(10, iterator_4_handle), (1, iterator_3_handle),
                         (20, iterator_4_handle), (2, iterator_3_handle),
                         (30, iterator_4_handle), (3, iterator_3_handle),
                         (40, iterator_4_handle)]
    for expected_value, handle in expected_sequence:
      self.assertEqual(
          expected_value,
          sess.run(next_element, feed_dict={handle_placeholder: handle}))

    # Both iterators are now exhausted.
    for handle in (iterator_3_handle, iterator_4_handle):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element, feed_dict={handle_placeholder: handle})
# Standard TensorFlow test entry point: run the suite when invoked directly.
if __name__ == "__main__":
  test.main()

View File

@ -182,6 +182,62 @@ class Iterator(object):
output_shapes=nest.flatten(output_shapes))
return Iterator(iterator_resource, None, output_types, output_shapes)
@staticmethod
def from_string_handle(string_handle, output_types, output_shapes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This method allows you to define a "feedable" iterator where you can
  choose between concrete iterators by feeding a value in a
  @{tf.Session.run} call. In that case, `string_handle` would be a
  @{tf.placeholder}, and you would feed it with the value of
  @{tf.contrib.data.Iterator.string_handle} in each step.

  For example, if you had two iterators that marked the current position
  in a training dataset and a test dataset, you could choose which to use
  in each step as follows:

  ```python
  train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.contrib.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
      to a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this iterator.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape`
      objects corresponding to each component of an element of this
      dataset. If omitted, each component will have an unconstrained shape.

  Returns:
    An `Iterator`.
  """
  # Canonicalize the dtypes first; the shapes structure must then mirror
  # the types structure exactly.
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if output_shapes is not None:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  else:
    # No shape constraint given: every component gets an unknown shape.
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  nest.assert_same_structure(output_types, output_shapes)

  string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
  resource = gen_dataset_ops.iterator_from_string_handle(string_handle)
  # No initializer (None): the iterator's state lives behind the handle.
  return Iterator(resource, None, output_types, output_shapes)
@property
def initializer(self):
"""A `tf.Operation` that should be run to initialize this iterator.
@ -261,6 +317,18 @@ class Iterator(object):
"""
return gen_dataset_ops.iterator_dispose(self._iterator_resource, name=name)
def string_handle(self, name=None):
  """Returns a string-valued `tf.Tensor` that represents this iterator.

  Args:
    name: (Optional.) A name for the created operation.

  Returns:
    A scalar `tf.Tensor` of type `tf.string`.
  """
  handle = gen_dataset_ops.iterator_to_string_handle(
      self._iterator_resource, name=name)
  return handle
@property
def output_shapes(self):
"""Returns the shape of each component of an element of this iterator.

View File

@ -415,6 +415,69 @@ class IteratorDisposeOp : public OpKernel {
}
};
// Serializes an iterator resource handle (input 0) into a string tensor,
// so it can later be fed back through IteratorFromStringHandleOp.
class IteratorToStringHandleOp : public OpKernel {
 public:
  explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& resource_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
    // The lookup was only for validation; release the reference right away.
    iterator_resource->Unref();

    Tensor* string_handle_t;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_output(0, TensorShape({}), &string_handle_t));
    string_handle_t->scalar<string>()() =
        resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
  }
};
// Parses a string produced by IteratorToStringHandleOp back into an
// iterator resource handle, validating that the referenced iterator
// exists on this device.
class IteratorFromStringHandleOp : public OpKernel {
 public:
  explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& string_handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
                errors::InvalidArgument("string_handle must be a scalar"));

    ResourceHandle resource_handle;
    OP_REQUIRES(
        ctx,
        resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
        errors::InvalidArgument(
            "Could not parse string_handle as a valid ResourceHandle"));

    // Check device ownership BEFORE looking the resource up: a handle
    // minted on another device would otherwise fail the lookup below with
    // a less informative "resource not found" error instead of this
    // targeted message. (NOTE(review): the message reads "Attempted
    // create"; kept verbatim so callers matching the string still work.)
    OP_REQUIRES(
        ctx, resource_handle.device() == ctx->device()->attributes().name(),
        errors::InvalidArgument("Attempted create an iterator on device \"",
                                ctx->device()->attributes().name(),
                                "\" from handle defined on device \"",
                                resource_handle.device(), "\""));

    // Validate that the handle corresponds to a real resource, and
    // that it is an IteratorResource.
    IteratorResource* iterator_resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, resource_handle, &iterator_resource));
    // The lookup was only for validation; release the reference right away.
    iterator_resource->Unref();

    Tensor* resource_handle_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
    resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
  }
};
// Kernel registrations for the iterator ops (CPU only).
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                        MakeIteratorOp);
@ -424,6 +487,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorDispose").Device(DEVICE_CPU),
                        IteratorDisposeOp);
// Registrations for the string-handle conversion kernels defined above.
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                        IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                        IteratorFromStringHandleOp);
} // namespace

View File

@ -533,4 +533,26 @@ REGISTER_OP("IteratorDispose")
Releases any resources used by the given iterator.
)doc");
// Op definition: serialize an iterator resource handle to a scalar string.
REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Converts the given `resource_handle` representing an iterator to a string.
resource_handle: A handle to an iterator resource.
string_handle: A string representation of the given handle.
)doc");
// Op definition: parse a serialized handle string back into an iterator
// resource handle (inverse of IteratorToStringHandle).
REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Converts the given string representing a handle to an iterator to a resource.
string_handle: A string representation of the given handle.
resource_handle: A handle to an iterator resource.
)doc");
} // namespace tensorflow