Add a mechanism for switching between multiple iterators by feeding a handle.
With this change, you can do the following: 1. Fetch a string handle for any iterator, by evaluating the result of `Iterator.string_handle()`. 2. Define an `Iterator` object based on a `tf.string` placeholder handle. 3. Feed the placeholder using an evaluated string handle to use a particular iterator in a particular step. Concretely, this allows you to define two iterators for a training dataset and a test dataset, and choose which one to use on a per-run basis: ```python train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator() train_iterator_handle = sess.run(train_iterator.string_handle()) test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator() test_iterator_handle = sess.run(test_iterator.string_handle()) handle = tf.placeholder(tf.string, shape=[]) iterator = tf.contrib.data.Iterator.from_string_handle( handle, train_iterator.output_types) next_element = iterator.get_next() loss = f(next_element) train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle}) test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle}) ``` PiperOrigin-RevId: 161719836
This commit is contained in:
parent
6d6dda807c
commit
71c4ec8ed6
@ -328,6 +328,54 @@ class IteratorTest(test.TestCase):
|
||||
[1, 2, 3], dtype=dtypes.int64), constant_op.constant(
|
||||
[4., 5., 6., 7.], dtype=dtypes.float64))))
|
||||
|
||||
def testIteratorStringHandle(self):
|
||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||
dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
|
||||
|
||||
iterator_3 = dataset_3.make_one_shot_iterator()
|
||||
iterator_4 = dataset_4.make_one_shot_iterator()
|
||||
|
||||
handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
|
||||
feedable_iterator = dataset_ops.Iterator.from_string_handle(
|
||||
handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
|
||||
next_element = feedable_iterator.get_next()
|
||||
|
||||
self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
|
||||
self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
|
||||
self.assertEqual([], feedable_iterator.output_shapes)
|
||||
|
||||
with self.test_session() as sess:
|
||||
iterator_3_handle = sess.run(iterator_3.string_handle())
|
||||
iterator_4_handle = sess.run(iterator_4.string_handle())
|
||||
|
||||
self.assertEqual(
|
||||
10, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
1, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
20, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
2, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
30, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
self.assertEqual(
|
||||
3, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle}))
|
||||
self.assertEqual(
|
||||
40, sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle}))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_3_handle})
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element,
|
||||
feed_dict={handle_placeholder: iterator_4_handle})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -182,6 +182,62 @@ class Iterator(object):
|
||||
output_shapes=nest.flatten(output_shapes))
|
||||
return Iterator(iterator_resource, None, output_types, output_shapes)
|
||||
|
||||
@staticmethod
|
||||
def from_string_handle(string_handle, output_types, output_shapes=None):
|
||||
"""Creates a new, uninitialized `Iterator` based on the given handle.
|
||||
|
||||
This method allows you to define a "feedable" iterator where you can choose
|
||||
between concrete iterators by feeding a value in a @{tf.Session.run} call.
|
||||
In that case, `string_handle` would a @{tf.placeholder}, and you would feed
|
||||
it with the value of @{tf.contrib.data.Iterator.string_handle} in each step.
|
||||
|
||||
For example, if you had two iterators that marked the current position in
|
||||
a training dataset and a test dataset, you could choose which to use in
|
||||
each step as follows:
|
||||
|
||||
```python
|
||||
train_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
|
||||
train_iterator_handle = sess.run(train_iterator.string_handle())
|
||||
|
||||
test_iterator = tf.contrib.data.Dataset(...).make_one_shot_iterator()
|
||||
test_iterator_handle = sess.run(test_iterator.string_handle())
|
||||
|
||||
handle = tf.placeholder(tf.string, shape=[])
|
||||
iterator = tf.contrib.data.Iterator.from_string_handle(
|
||||
handle, train_iterator.output_types)
|
||||
|
||||
next_element = iterator.get_next()
|
||||
loss = f(next_element)
|
||||
|
||||
train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
|
||||
test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
|
||||
```
|
||||
|
||||
Args:
|
||||
string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
|
||||
to a handle produced by the `Iterator.string_handle()` method.
|
||||
output_types: A nested structure of `tf.DType` objects corresponding to
|
||||
each component of an element of this iterator.
|
||||
output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
|
||||
corresponding to each component of an element of this dataset. If
|
||||
omitted, each component will have an unconstrainted shape.
|
||||
|
||||
Returns:
|
||||
An `Iterator`.
|
||||
"""
|
||||
output_types = nest.map_structure(dtypes.as_dtype, output_types)
|
||||
if output_shapes is None:
|
||||
output_shapes = nest.map_structure(
|
||||
lambda _: tensor_shape.TensorShape(None), output_types)
|
||||
else:
|
||||
output_shapes = nest.map_structure_up_to(
|
||||
output_types, tensor_shape.as_shape, output_shapes)
|
||||
nest.assert_same_structure(output_types, output_shapes)
|
||||
string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
|
||||
iterator_resource = gen_dataset_ops.iterator_from_string_handle(
|
||||
string_handle)
|
||||
return Iterator(iterator_resource, None, output_types, output_shapes)
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
"""A `tf.Operation` that should be run to initialize this iterator.
|
||||
@ -261,6 +317,18 @@ class Iterator(object):
|
||||
"""
|
||||
return gen_dataset_ops.iterator_dispose(self._iterator_resource, name=name)
|
||||
|
||||
def string_handle(self, name=None):
|
||||
"""Returns a string-valued `tf.Tensor` that represents this iterator.
|
||||
|
||||
Args:
|
||||
name: (Optional.) A name for the created operation.
|
||||
|
||||
Returns:
|
||||
A scalar `tf.Tensor` of type `tf.string`.
|
||||
"""
|
||||
return gen_dataset_ops.iterator_to_string_handle(self._iterator_resource,
|
||||
name=name)
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
"""Returns the shape of each component of an element of this iterator.
|
||||
|
@ -415,6 +415,69 @@ class IteratorDisposeOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
class IteratorToStringHandleOp : public OpKernel {
|
||||
public:
|
||||
explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& resource_handle_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||
|
||||
// Validate that the handle corresponds to a real resource, and
|
||||
// that it is an IteratorResource.
|
||||
IteratorResource* iterator_resource;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||
iterator_resource->Unref();
|
||||
|
||||
Tensor* string_handle_t;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(0, TensorShape({}), &string_handle_t));
|
||||
string_handle_t->scalar<string>()() =
|
||||
resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
|
||||
}
|
||||
};
|
||||
|
||||
class IteratorFromStringHandleOp : public OpKernel {
|
||||
public:
|
||||
explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& string_handle_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
|
||||
errors::InvalidArgument("string_handle must be a scalar"));
|
||||
|
||||
ResourceHandle resource_handle;
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
|
||||
errors::InvalidArgument(
|
||||
"Could not parse string_handle as a valid ResourceHandle"));
|
||||
|
||||
// Validate that the handle corresponds to a real resource, and
|
||||
// that it is an IteratorResource.
|
||||
IteratorResource* iterator_resource;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, resource_handle, &iterator_resource));
|
||||
iterator_resource->Unref();
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, resource_handle.device() == ctx->device()->attributes().name(),
|
||||
errors::InvalidArgument("Attempted create an iterator on device \"",
|
||||
ctx->device()->attributes().name(),
|
||||
"\" from handle defined on device \"",
|
||||
resource_handle.device(), "\""));
|
||||
|
||||
Tensor* resource_handle_t;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
|
||||
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
|
||||
MakeIteratorOp);
|
||||
@ -424,6 +487,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
|
||||
IteratorGetNextOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorDispose").Device(DEVICE_CPU),
|
||||
IteratorDisposeOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
|
||||
IteratorToStringHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
|
||||
IteratorFromStringHandleOp);
|
||||
|
||||
} // namespace
|
||||
|
||||
|
@ -533,4 +533,26 @@ REGISTER_OP("IteratorDispose")
|
||||
Releases any resources used by the given iterator.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("IteratorToStringHandle")
|
||||
.Input("resource_handle: resource")
|
||||
.Output("string_handle: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Converts the given `resource_handle` representing an iterator to a string.
|
||||
|
||||
resource_handle: A handle to an iterator resource.
|
||||
string_handle: A string representation of the given handle.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("IteratorFromStringHandle")
|
||||
.Input("string_handle: string")
|
||||
.Output("resource_handle: resource")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Converts the given string representing a handle to an iterator to a resource.
|
||||
|
||||
string_handle: A string representation of the given handle.
|
||||
resource_handle: A handle to an iterator resource.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user