Support partial gets in MapStagingArea (#10276)
* Modify map staging area tests - size from `small` to `medium` - introduce 2 shards * Add partial get support in MapStagingArea A partial list of tensors in a (key, value) map entry can now be requested. Once all tensors associated with the entry are removed, it is removed from the map. * Correct output/indices mismatch errors * Rename IncompleteTuple to OptionalTuple * Add partial get test with indices * Add some more index checks * Improve stage test case graph creation Test sessions (and default graphs) are reused by default. Create explicit, finalized graphs in each test to prevent possible interactions between stateful Staging Areas and others ops created in separate tests. * Make staging area tests small and remove shards They were originally made 'medium' to ameliorate timeouts in the test case, but they usually run in ~1s so they should be small. * Improve imports Avoid importing base tensorflow package * Support both python 2 and python 3 range. * Set map_stage_op_test to size=large * Convert the tests to size=medium
This commit is contained in:
parent
0df102b0a0
commit
8118ab4ec9
@ -86,13 +86,13 @@ public:
|
||||
// Public typedefs
|
||||
typedef std::vector<Tensor> Tuple;
|
||||
typedef gtl::optional<Tensor> OptionalTensor;
|
||||
typedef std::vector<OptionalTensor> IncompleteTuple;
|
||||
typedef std::vector<OptionalTensor> OptionalTuple;
|
||||
|
||||
typedef MapTraits<Ordered, Tuple> MapTraits_;
|
||||
typedef MapTraits<Ordered, OptionalTuple> MapTraits_;
|
||||
typedef typename MapTraits_::MapType MapType;
|
||||
typedef typename MapTraits_::KeyType KeyType;
|
||||
|
||||
typedef MapTraits<false, IncompleteTuple> IncompleteTraits;
|
||||
typedef MapTraits<false, OptionalTuple> IncompleteTraits;
|
||||
typedef typename IncompleteTraits::MapType IncompleteType;
|
||||
|
||||
private:
|
||||
@ -150,6 +150,16 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
// Get number of bytes in the incomplete tuple
|
||||
inline std::size_t get_tuple_bytes(const OptionalTuple & tuple)
|
||||
{
|
||||
return std::accumulate(tuple.begin(), tuple.end(), 0,
|
||||
[](const std::size_t & lhs, const OptionalTensor & rhs) {
|
||||
return lhs + rhs.has_value() ? rhs.value().TotalBytes() : 0;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
// Check that the index is within bounds
|
||||
inline Status check_index(const Tensor & key, std::size_t index)
|
||||
{
|
||||
@ -163,12 +173,47 @@ private:
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
inline Status copy_or_move_tensors(OptionalTuple & map_tuple,
|
||||
const Tensor & key,
|
||||
const Tensor & indices,
|
||||
Tuple * output,
|
||||
bool copy=false)
|
||||
{
|
||||
auto findices = indices.flat<int>();
|
||||
|
||||
// Return values at specified indices
|
||||
for(std::size_t i = 0; i < findices.dimension(0); ++i)
|
||||
{
|
||||
std::size_t index = findices(i);
|
||||
|
||||
TF_RETURN_IF_ERROR(check_index(key, index));
|
||||
|
||||
// Insist on a value present at the specified index
|
||||
if(!map_tuple[index].has_value())
|
||||
{
|
||||
return Status(errors::InvalidArgument("Tensor at index '",
|
||||
index, "' for key '", key.scalar<int64>()(),
|
||||
"' has already been removed."));
|
||||
}
|
||||
|
||||
// Copy the contained tensor and
|
||||
// remove from the OptionalTuple
|
||||
output->push_back(map_tuple[index].value());
|
||||
|
||||
// Clear out the entry if we're not copying (moving)
|
||||
if(!copy) {
|
||||
map_tuple[index].reset();
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Check that the optional value at the specified index
|
||||
// is uninitialized
|
||||
inline Status check_index_uninitialized(const Tensor & key,
|
||||
std::size_t index,
|
||||
const IncompleteTuple & tuple)
|
||||
const OptionalTuple & tuple)
|
||||
{
|
||||
if(tuple[index].has_value())
|
||||
{
|
||||
@ -212,7 +257,7 @@ private:
|
||||
// Insert incomplete data into the Barrier
|
||||
Status put_incomplete(const KeyType & key,
|
||||
const Tensor & indices,
|
||||
Tuple * tuple,
|
||||
OptionalTuple * tuple,
|
||||
mutex_lock &l)
|
||||
{
|
||||
auto findices = indices.flat<int>();
|
||||
@ -233,10 +278,10 @@ private:
|
||||
}
|
||||
|
||||
// This key isn't present in the incomplete set
|
||||
// Create IncompleteTuple and insert
|
||||
// Create OptionalTuple and insert
|
||||
if(it == incomplete_.end())
|
||||
{
|
||||
IncompleteTuple empty(dtypes_.size());
|
||||
OptionalTuple empty(dtypes_.size());
|
||||
|
||||
// Initialize empty tuple with given dta
|
||||
for(std::size_t i = 0; i < findices.dimension(0); ++i)
|
||||
@ -260,7 +305,7 @@ private:
|
||||
else
|
||||
{
|
||||
// Reference existing incomplete tuple
|
||||
IncompleteTuple & present = it->second;
|
||||
OptionalTuple & present = it->second;
|
||||
|
||||
// Assign given data
|
||||
for(std::size_t i = 0; i < findices.dimension(0); ++i)
|
||||
@ -284,16 +329,12 @@ private:
|
||||
// If so, put the tuple in the actual map
|
||||
if(complete)
|
||||
{
|
||||
// Create a tuple for insertion
|
||||
Tuple new_tuple;
|
||||
|
||||
for(const auto & v: present)
|
||||
{ new_tuple.push_back(v.value()); }
|
||||
OptionalTuple insert_tuple = std::move(it->second);
|
||||
|
||||
// Remove from incomplete
|
||||
incomplete_.erase(it);
|
||||
|
||||
TF_RETURN_IF_ERROR(put_complete(key, &new_tuple, l));
|
||||
TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple, l));
|
||||
}
|
||||
}
|
||||
|
||||
@ -301,7 +342,7 @@ private:
|
||||
}
|
||||
|
||||
// Does the insertion into the actual staging area
|
||||
Status put_complete(const KeyType & key, Tuple * tuple,
|
||||
Status put_complete(const KeyType & key, OptionalTuple * tuple,
|
||||
mutex_lock & l)
|
||||
{
|
||||
// Insert key and tuples into the map
|
||||
@ -322,7 +363,7 @@ public:
|
||||
current_bytes_(0) {}
|
||||
|
||||
Status put(KeyType* key, const Tensor * indices,
|
||||
Tuple* tuple)
|
||||
OptionalTuple* tuple)
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
|
||||
@ -362,11 +403,15 @@ public:
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status get(const KeyType* key, Tuple* tuple)
|
||||
Status get(const KeyType* key, const Tensor * indices,
|
||||
Tuple* tuple)
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
|
||||
typename MapType::const_iterator it;
|
||||
// Sanity check the indices
|
||||
TF_RETURN_IF_ERROR(check_index_ordering(*indices));
|
||||
|
||||
typename MapType::iterator it;
|
||||
|
||||
// Wait until the element with the requested key is present
|
||||
not_empty_.wait(l, [&, this]() {
|
||||
@ -374,9 +419,9 @@ public:
|
||||
return it != map_.end();
|
||||
});
|
||||
|
||||
// Copy tensors into the tuple
|
||||
for(const auto & tensor : it->second)
|
||||
{ tuple->push_back(tensor); }
|
||||
TF_RETURN_IF_ERROR(copy_or_move_tensors(it->second, *key,
|
||||
*indices, tuple,
|
||||
true));
|
||||
|
||||
// Update bytes in the Staging Area
|
||||
current_bytes_ -= get_tuple_bytes(*tuple);
|
||||
@ -384,10 +429,13 @@ public:
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status pop(const KeyType* key, Tuple* tuple)
|
||||
Status pop(const KeyType* key, const Tensor * indices, Tuple* tuple)
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
|
||||
// Sanity check the indices
|
||||
TF_RETURN_IF_ERROR(check_index_ordering(*indices));
|
||||
|
||||
typename MapType::iterator it;
|
||||
|
||||
// Wait until the element with the requested key is present
|
||||
@ -396,11 +444,16 @@ public:
|
||||
return it != this->map_.end();
|
||||
});
|
||||
|
||||
// Move from the entry as its erased anyway
|
||||
*tuple = std::move(it->second);
|
||||
TF_RETURN_IF_ERROR(copy_or_move_tensors(it->second, *key,
|
||||
*indices, tuple));
|
||||
|
||||
// Remove
|
||||
map_.erase(it);
|
||||
// Remove entry if all the values have been consumed
|
||||
bool any_left = std::any_of(it->second.begin(), it->second.end(),
|
||||
[](const OptionalTensor & T) { return T.has_value(); });
|
||||
|
||||
if(!any_left) {
|
||||
map_.erase(it);
|
||||
}
|
||||
|
||||
// Update bytes in the Staging Area
|
||||
current_bytes_ -= get_tuple_bytes(*tuple);
|
||||
@ -410,17 +463,32 @@ public:
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status popitem(KeyType* key, Tuple* tuple)
|
||||
Status popitem(KeyType* key, const Tensor * indices, Tuple* tuple)
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
|
||||
// Sanity check the indices
|
||||
TF_RETURN_IF_ERROR(check_index_ordering(*indices));
|
||||
|
||||
// Wait until map is not empty
|
||||
not_empty_.wait(l, [this]() { return !this->map_.empty(); });
|
||||
|
||||
// Move from the first element and erase it
|
||||
*tuple = std::move(map_.begin()->second);
|
||||
*key = map_.begin()->first;
|
||||
map_.erase(map_.begin());
|
||||
|
||||
auto it = map_.begin();
|
||||
|
||||
TF_RETURN_IF_ERROR(copy_or_move_tensors(it->second, *key,
|
||||
*indices, tuple));
|
||||
|
||||
*key = it->first;
|
||||
|
||||
// Remove entry if all the values have been consumed
|
||||
bool any_left = std::any_of(it->second.begin(), it->second.end(),
|
||||
[](const OptionalTensor & T) { return T.has_value(); });
|
||||
|
||||
if(!any_left) {
|
||||
map_.erase(it);
|
||||
}
|
||||
|
||||
// Update bytes in the Staging Area
|
||||
current_bytes_ -= get_tuple_bytes(*tuple);
|
||||
@ -499,7 +567,7 @@ class MapStageOp : public OpKernel
|
||||
StagingMap<Ordered>* map = nullptr;
|
||||
OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
|
||||
core::ScopedUnref scope(map);
|
||||
typename StagingMap<Ordered>::Tuple tuple;
|
||||
typename StagingMap<Ordered>::OptionalTuple tuple;
|
||||
|
||||
const Tensor * key_tensor;
|
||||
const Tensor * indices_tensor;
|
||||
@ -560,15 +628,18 @@ class MapUnstageOp : public OpKernel
|
||||
typename StagingMap<Ordered>::Tuple tuple;
|
||||
|
||||
const Tensor * key_tensor;
|
||||
const Tensor * indices_tensor;
|
||||
OpInputList values_tensor;
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
|
||||
OP_REQUIRES_OK(ctx, map->pop(key_tensor, &tuple));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
|
||||
OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));
|
||||
|
||||
OP_REQUIRES(ctx,
|
||||
tuple.size() == indices_tensor->NumElements(),
|
||||
errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
|
||||
" vs. ", indices_tensor->NumElements()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, tuple.size() == (size_t)ctx->num_outputs(),
|
||||
errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
|
||||
" vs. ", ctx->num_outputs()));
|
||||
for (size_t i = 0; i < tuple.size(); ++i) {
|
||||
ctx->set_output(i, tuple[i]);
|
||||
}
|
||||
@ -581,16 +652,24 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
|
||||
MapUnstageOp<true>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstage").HostMemory("key")
|
||||
.Device(DEVICE_GPU), MapUnstageOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").HostMemory("key")
|
||||
.Device(DEVICE_GPU), MapUnstageOp<true>);
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_GPU), MapUnstageOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_GPU), MapUnstageOp<true>);
|
||||
#endif
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstage").HostMemory("key")
|
||||
.Device(DEVICE_SYCL), MapUnstageOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").HostMemory("key")
|
||||
.Device(DEVICE_SYCL), MapUnstageOp<true>);
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_SYCL), MapUnstageOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_SYCL), MapUnstageOp<true>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
template <bool Ordered>
|
||||
@ -608,15 +687,18 @@ class MapPeekOp : public OpKernel
|
||||
typename StagingMap<Ordered>::Tuple tuple;
|
||||
|
||||
const Tensor * key_tensor;
|
||||
const Tensor * indices_tensor;
|
||||
OpInputList values_tensor;
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
|
||||
OP_REQUIRES_OK(ctx, map->get(key_tensor, &tuple));
|
||||
OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
|
||||
OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));
|
||||
|
||||
OP_REQUIRES(ctx,
|
||||
tuple.size() == indices_tensor->NumElements(),
|
||||
errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
|
||||
" vs. ", indices_tensor->NumElements()));
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, tuple.size() == (size_t)ctx->num_outputs(),
|
||||
errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
|
||||
" vs. ", ctx->num_outputs()));
|
||||
for (size_t i = 0; i < tuple.size(); ++i) {
|
||||
ctx->set_output(i, tuple[i]);
|
||||
}
|
||||
@ -629,15 +711,23 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
|
||||
MapPeekOp<true>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL_BUILDER(Name("MapPeek").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("MapPeek")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_GPU), MapPeekOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_GPU), MapPeekOp<true>);
|
||||
#endif
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("MapPeek").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("MapPeek")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_SYCL), MapPeekOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_SYCL), MapPeekOp<true>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
@ -660,18 +750,21 @@ class MapUnstageNoKeyOp : public OpKernel
|
||||
typename StagingMap<Ordered>::KeyType key;
|
||||
typename StagingMap<Ordered>::Tuple tuple;
|
||||
|
||||
OP_REQUIRES_OK(ctx, map->popitem(&key, &tuple));
|
||||
const Tensor * indices_tensor;
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
|
||||
OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));
|
||||
|
||||
// Allocate a key tensor and assign the key as the first output
|
||||
ctx->set_output(0, key);
|
||||
|
||||
// Set the rest of the outputs to the tuple Tensors
|
||||
OP_REQUIRES(ctx,
|
||||
tuple.size() == (size_t)ctx->num_outputs()-1,
|
||||
errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
|
||||
" vs. ", ctx->num_outputs()-1));
|
||||
for (size_t i = 0; i < tuple.size(); ++i)
|
||||
{
|
||||
tuple.size() == indices_tensor->NumElements(),
|
||||
errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
|
||||
" vs. ", indices_tensor->NumElements()));
|
||||
|
||||
for (size_t i = 0; i < tuple.size(); ++i) {
|
||||
ctx->set_output(i+1, tuple[i]);
|
||||
}
|
||||
}
|
||||
@ -683,16 +776,24 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
|
||||
MapUnstageNoKeyOp<true>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_GPU), MapUnstageNoKeyOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_GPU), MapUnstageNoKeyOp<true>);
|
||||
|
||||
#endif
|
||||
#ifdef TENSORFLOW_USE_SYCL
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_SYCL), MapUnstageNoKeyOp<false>);
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").HostMemory("key")
|
||||
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
|
||||
.HostMemory("key")
|
||||
.HostMemory("indices")
|
||||
.Device(DEVICE_SYCL), MapUnstageNoKeyOp<true>);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
|
@ -2078,6 +2078,7 @@ shared_name: It is necessary to match this name to the matching Unstage Op.
|
||||
|
||||
REGISTER_OP("MapPeek")
|
||||
.Input("key: int64")
|
||||
.Input("indices: int32")
|
||||
.Output("values: dtypes")
|
||||
.Attr("capacity: int >= 0 = 0")
|
||||
.Attr("memory_limit: int >= 0 = 0")
|
||||
@ -2094,6 +2095,7 @@ this op will block until it does.
|
||||
|
||||
REGISTER_OP("MapUnstage")
|
||||
.Input("key: int64")
|
||||
.Input("indices: int32")
|
||||
.Output("values: dtypes")
|
||||
.Attr("capacity: int >= 0 = 0")
|
||||
.Attr("memory_limit: int >= 0 = 0")
|
||||
@ -2109,6 +2111,7 @@ does not contain this key, the op will block until it does.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MapUnstageNoKey")
|
||||
.Input("indices: int32")
|
||||
.Output("key: int64")
|
||||
.Output("values: dtypes")
|
||||
.Attr("capacity: int >= 0 = 0")
|
||||
@ -2193,6 +2196,7 @@ shared_name: It is necessary to match this name to the matching Unstage Op.
|
||||
|
||||
REGISTER_OP("OrderedMapPeek")
|
||||
.Input("key: int64")
|
||||
.Input("indices: int32")
|
||||
.Output("values: dtypes")
|
||||
.Attr("capacity: int >= 0 = 0")
|
||||
.Attr("memory_limit: int >= 0 = 0")
|
||||
@ -2210,6 +2214,7 @@ performance.
|
||||
|
||||
REGISTER_OP("OrderedMapUnstage")
|
||||
.Input("key: int64")
|
||||
.Input("indices: int32")
|
||||
.Output("values: dtypes")
|
||||
.Attr("capacity: int >= 0 = 0")
|
||||
.Attr("memory_limit: int >= 0 = 0")
|
||||
@ -2225,6 +2230,7 @@ does not contain this key, the op will block until it does.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("OrderedMapUnstageNoKey")
|
||||
.Input("indices: int32")
|
||||
.Output("key: int64")
|
||||
.Output("values: dtypes")
|
||||
.Attr("capacity: int >= 0 = 0")
|
||||
|
@ -2387,12 +2387,11 @@ cuda_py_test(
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
],
|
||||
shard_count = 2,
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "map_stage_op_test",
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["map_stage_op_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -16,6 +16,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -27,7 +28,7 @@ from tensorflow.python.platform import test
|
||||
class MapStageTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
@ -38,13 +39,17 @@ class MapStageTest(test.TestCase):
|
||||
stage = stager.put(pi, [v], [0])
|
||||
k, y = stager.get(gi)
|
||||
y = math_ops.reduce_max(math_ops.matmul(y, y))
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1, pi: 0})
|
||||
for i in range(10):
|
||||
_, yval = sess.run([stage, y], feed_dict={x: i, pi: i+1, gi:i})
|
||||
self.assertAllClose(4 * (i - 1) * (i - 1) * 128, yval, rtol=1e-4)
|
||||
|
||||
def testMultiple(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
@ -55,6 +60,10 @@ class MapStageTest(test.TestCase):
|
||||
stage = stager.put(pi, [x, v], [0, 1])
|
||||
k, (z, y) = stager.get(gi)
|
||||
y = math_ops.reduce_max(z * math_ops.matmul(y, y))
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1, pi: 0})
|
||||
for i in range(10):
|
||||
_, yval = sess.run([stage, y], feed_dict={x: i, pi: i+1, gi:i})
|
||||
@ -62,7 +71,7 @@ class MapStageTest(test.TestCase):
|
||||
4 * (i - 1) * (i - 1) * (i - 1) * 128, yval, rtol=1e-4)
|
||||
|
||||
def testDictionary(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
@ -78,6 +87,10 @@ class MapStageTest(test.TestCase):
|
||||
z = ret['x']
|
||||
y = ret['v']
|
||||
y = math_ops.reduce_max(z * math_ops.matmul(y, y))
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1, pi: 0})
|
||||
for i in range(10):
|
||||
_, yval = sess.run([stage, y], feed_dict={x: i, pi: i+1, gi:i})
|
||||
@ -87,37 +100,43 @@ class MapStageTest(test.TestCase):
|
||||
def testColocation(self):
|
||||
gpu_dev = test.gpu_device_name()
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(gpu_dev):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.float32])
|
||||
y = stager.put(1, [v], [0])
|
||||
self.assertEqual(y.device, '/device:GPU:0' if gpu_dev
|
||||
else gpu_dev)
|
||||
with ops.device('/cpu:0'):
|
||||
_, x = stager.get(1)
|
||||
y = stager.peek(1)
|
||||
_, z = stager.get()
|
||||
self.assertEqual(x.device, '/device:CPU:0')
|
||||
self.assertEqual(y.device, '/device:CPU:0')
|
||||
self.assertEqual(z.device, '/device:CPU:0')
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(gpu_dev):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.float32])
|
||||
y = stager.put(1, [v], [0])
|
||||
self.assertEqual(y.device, '/device:GPU:0' if gpu_dev
|
||||
else gpu_dev)
|
||||
with ops.device('/cpu:0'):
|
||||
_, x = stager.get(1)
|
||||
y = stager.peek(1)
|
||||
_, z = stager.get()
|
||||
self.assertEqual(x.device, '/device:CPU:0')
|
||||
self.assertEqual(y.device, '/device:CPU:0')
|
||||
self.assertEqual(z.device, '/device:CPU:0')
|
||||
|
||||
G.finalize()
|
||||
|
||||
def testPeek(self):
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
p = array_ops.placeholder(dtypes.int32, name='p')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.int32, ], shapes=[[]])
|
||||
stage = stager.put(pi,[x], [0])
|
||||
peek = stager.peek(gi)
|
||||
size = stager.size()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
p = array_ops.placeholder(dtypes.int32, name='p')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.int32, ], shapes=[[]])
|
||||
stage = stager.put(pi,[x], [0])
|
||||
peek = stager.peek(gi)
|
||||
size = stager.size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
n = 10
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
for i in range(n):
|
||||
sess.run(stage, feed_dict={x:i, pi:i})
|
||||
|
||||
@ -127,21 +146,24 @@ class MapStageTest(test.TestCase):
|
||||
self.assertTrue(sess.run(size) == 10)
|
||||
|
||||
def testSizeAndClear(self):
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea(
|
||||
[dtypes.float32, dtypes.float32],
|
||||
shapes=[[], [128, 128]],
|
||||
names=['x', 'v'])
|
||||
stage = stager.put(pi,{'x': x, 'v': v})
|
||||
size = stager.size()
|
||||
clear = stager.clear()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea(
|
||||
[dtypes.float32, dtypes.float32],
|
||||
shapes=[[], [128, 128]],
|
||||
names=['x', 'v'])
|
||||
stage = stager.put(pi,{'x': x, 'v': v})
|
||||
size = stager.size()
|
||||
clear = stager.clear()
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1, pi: 3})
|
||||
self.assertEqual(sess.run(size), 1)
|
||||
sess.run(stage, feed_dict={x: -1, pi: 1})
|
||||
@ -153,18 +175,21 @@ class MapStageTest(test.TestCase):
|
||||
def testCapacity(self):
|
||||
capacity = 3
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64, name='pi')
|
||||
gi = array_ops.placeholder(dtypes.int64, name='gi')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
|
||||
capacity=capacity, shapes=[[]])
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64, name='pi')
|
||||
gi = array_ops.placeholder(dtypes.int64, name='gi')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
|
||||
capacity=capacity, shapes=[[]])
|
||||
|
||||
stage = stager.put(pi, [x], [0])
|
||||
get = stager.get()
|
||||
size = stager.size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
from six.moves import queue as Queue
|
||||
import threading
|
||||
|
||||
@ -172,7 +197,7 @@ class MapStageTest(test.TestCase):
|
||||
n = 5
|
||||
missed = 0
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# Stage data in a separate thread which will block
|
||||
# when it hits the staging area's capacity and thus
|
||||
# not fill the queue with n tokens
|
||||
@ -212,16 +237,19 @@ class MapStageTest(test.TestCase):
|
||||
chunk = 200*1024 # 256K
|
||||
capacity = memory_limit // chunk
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.uint8, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64, name='pi')
|
||||
gi = array_ops.placeholder(dtypes.int64, name='gi')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.uint8],
|
||||
memory_limit=memory_limit, shapes=[[]])
|
||||
stage = stager.put(pi, [x], [0])
|
||||
get = stager.get()
|
||||
size = stager.size()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.uint8, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64, name='pi')
|
||||
gi = array_ops.placeholder(dtypes.int64, name='gi')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.uint8],
|
||||
memory_limit=memory_limit, shapes=[[]])
|
||||
stage = stager.put(pi, [x], [0])
|
||||
get = stager.get()
|
||||
size = stager.size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
from six.moves import queue as Queue
|
||||
import threading
|
||||
@ -231,7 +259,7 @@ class MapStageTest(test.TestCase):
|
||||
n = 5
|
||||
missed = 0
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# Stage data in a separate thread which will block
|
||||
# when it hits the staging area's capacity and thus
|
||||
# not fill the queue with n tokens
|
||||
@ -271,20 +299,23 @@ class MapStageTest(test.TestCase):
|
||||
import six
|
||||
import random
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64, name='pi')
|
||||
gi = array_ops.placeholder(dtypes.int64, name='gi')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
|
||||
shapes=[[]], ordered=True)
|
||||
stage = stager.put(pi, [x], [0])
|
||||
get = stager.get()
|
||||
size = stager.size()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
pi = array_ops.placeholder(dtypes.int64, name='pi')
|
||||
gi = array_ops.placeholder(dtypes.int64, name='gi')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea([dtypes.int32, ],
|
||||
shapes=[[]], ordered=True)
|
||||
stage = stager.put(pi, [x], [0])
|
||||
get = stager.get()
|
||||
size = stager.size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
n = 10
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# Keys n-1..0
|
||||
keys = list(reversed(six.moves.range(n)))
|
||||
|
||||
@ -300,8 +331,8 @@ class MapStageTest(test.TestCase):
|
||||
|
||||
self.assertTrue(sess.run(size) == 0)
|
||||
|
||||
def testBarrier(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
def testPartialDictInsert(self):
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
f = array_ops.placeholder(dtypes.float32)
|
||||
@ -319,32 +350,43 @@ class MapStageTest(test.TestCase):
|
||||
size = stager.size()
|
||||
isize = stager.incomplete_size()
|
||||
|
||||
# 0 complete and incomplete entries
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
# Stage key 0, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Stage key 1, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 2])
|
||||
G.finalize()
|
||||
|
||||
# Now complete key 0 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 0, v: 1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
# We can now obtain tuple associated with key 0
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:0})
|
||||
== [0, { 'x':1, 'f':2, 'v':1}])
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# 0 complete and incomplete entries
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
# Stage key 0, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Stage key 1, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 2])
|
||||
|
||||
# 0 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Now complete key 1 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 1, v: 3})
|
||||
# We can now obtain tuple associated with key 1
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:1})
|
||||
== [1, { 'x':1, 'f':2, 'v':3}])
|
||||
# Now complete key 0 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 0, v: 1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
# We can now obtain tuple associated with key 0
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:0})
|
||||
== [0, { 'x':1, 'f':2, 'v':1}])
|
||||
|
||||
# Test again with index inserts
|
||||
# 0 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Now complete key 1 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 1, v: 3})
|
||||
# We can now obtain tuple associated with key 1
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:1})
|
||||
== [1, { 'x':1, 'f':2, 'v':3}])
|
||||
|
||||
def testPartialIndexInsert(self):
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
f = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.placeholder(dtypes.float32)
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.MapStagingArea(
|
||||
[dtypes.float32, dtypes.float32, dtypes.float32])
|
||||
stage_xf = stager.put(pi, [x, f], [0, 2])
|
||||
@ -353,31 +395,162 @@ class MapStageTest(test.TestCase):
|
||||
size = stager.size()
|
||||
isize = stager.incomplete_size()
|
||||
|
||||
# 0 complete and incomplete entries
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
# Stage key 0, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Stage key 1, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 2])
|
||||
G.finalize()
|
||||
|
||||
# Now complete key 0 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 0, v: 1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
# We can now obtain tuple associated with key 0
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:0})
|
||||
== [0, [1, 1, 2]])
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# 0 complete and incomplete entries
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
# Stage key 0, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Stage key 1, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 2])
|
||||
|
||||
# 0 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Now complete key 1 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 1, v: 3})
|
||||
# We can now obtain tuple associated with key 1
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:1})
|
||||
== [1, [1,3, 2]])
|
||||
# Now complete key 0 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 0, v: 1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
# We can now obtain tuple associated with key 0
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:0})
|
||||
== [0, [1, 1, 2]])
|
||||
|
||||
# 0 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Now complete key 1 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 1, v: 3})
|
||||
# We can now obtain tuple associated with key 1
|
||||
self.assertTrue(sess.run([key, ret], feed_dict={gi:1})
|
||||
== [1, [1,3, 2]])
|
||||
|
||||
def testPartialDictGetsAndPeeks(self):
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
f = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.placeholder(dtypes.float32)
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
pei = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
# Test barrier with dictionary
|
||||
stager = data_flow_ops.MapStagingArea(
|
||||
[dtypes.float32, dtypes.float32, dtypes.float32],
|
||||
names=['x', 'v', 'f'])
|
||||
stage_xf = stager.put(pi,{'x': x, 'f': f})
|
||||
stage_v = stager.put(pi, {'v': v})
|
||||
peek_xf = stager.peek(pei, ['x', 'f'])
|
||||
peek_v = stager.peek(pei, ['v'])
|
||||
key_xf, get_xf = stager.get(gi, ['x', 'f'])
|
||||
key_v, get_v = stager.get(gi, ['v'])
|
||||
pop_key_xf, pop_xf = stager.get(indices=['x', 'f'])
|
||||
pop_key_v, pop_v = stager.get(pi, ['v'])
|
||||
size = stager.size()
|
||||
isize = stager.incomplete_size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# 0 complete and incomplete entries
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
# Stage key 0, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 0, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
# Stage key 1, x and f tuple entries
|
||||
sess.run(stage_xf, feed_dict={pi: 1, x: 1, f: 2})
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 2])
|
||||
|
||||
# Now complete key 0 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 0, v: 1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
|
||||
# We can now peek at 'x' and 'f' values associated with key 0
|
||||
self.assertTrue(sess.run(peek_xf, feed_dict={pei:0})
|
||||
== { 'x':1, 'f':2})
|
||||
# Peek at 'v' value associated with key 0
|
||||
self.assertTrue(sess.run(peek_v, feed_dict={pei:0})
|
||||
== { 'v':1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
|
||||
# We can now obtain 'x' and 'f' values associated with key 0
|
||||
self.assertTrue(sess.run([key_xf, get_xf], feed_dict={gi:0})
|
||||
== [0, { 'x':1, 'f':2}])
|
||||
# Still have 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 1])
|
||||
|
||||
# We can no longer get 'x' and 'f' from key 0
|
||||
with self.assertRaises(errors.InvalidArgumentError) as cm:
|
||||
sess.run([key_xf, get_xf], feed_dict={gi:0})
|
||||
|
||||
exc_str = ("Tensor at index '0' for key '0' "
|
||||
"has already been removed.")
|
||||
|
||||
self.assertTrue(exc_str in cm.exception.message)
|
||||
|
||||
# Obtain 'v' value associated with key 0
|
||||
self.assertTrue(sess.run([key_v, get_v], feed_dict={gi:0})
|
||||
== [0, { 'v':1}])
|
||||
# 0 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 1])
|
||||
|
||||
# Now complete key 1 with tuple entry v
|
||||
sess.run(stage_v, feed_dict={pi: 1, v: 1})
|
||||
# 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 0])
|
||||
|
||||
# Pop without key to obtain 'x' and 'f' values associated with key 1
|
||||
self.assertTrue(sess.run([pop_key_xf, pop_xf])
|
||||
== [1, { 'x':1, 'f':2}])
|
||||
# still 1 complete and 1 incomplete entry
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 0])
|
||||
# We can now obtain 'x' and 'f' values associated with key 1
|
||||
self.assertTrue(sess.run([pop_key_v, pop_v], feed_dict={pi:1})
|
||||
== [1, { 'v': 1 }])
|
||||
# Nothing is left
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
|
||||
def testPartialIndexGets(self):
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
f = array_ops.placeholder(dtypes.float32)
|
||||
v = array_ops.placeholder(dtypes.float32)
|
||||
pi = array_ops.placeholder(dtypes.int64)
|
||||
pei = array_ops.placeholder(dtypes.int64)
|
||||
gi = array_ops.placeholder(dtypes.int64)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
# Test again with partial index gets
|
||||
stager = data_flow_ops.MapStagingArea(
|
||||
[dtypes.float32, dtypes.float32, dtypes.float32])
|
||||
stage_xvf = stager.put(pi, [x, v, f], [0, 1, 2])
|
||||
key_xf, get_xf = stager.get(gi, [0, 2])
|
||||
key_v, get_v = stager.get(gi, [1])
|
||||
size = stager.size()
|
||||
isize = stager.incomplete_size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# Stage complete tuple
|
||||
sess.run(stage_xvf, feed_dict={pi: 0, x: 1, f: 2, v: 3})
|
||||
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 0])
|
||||
|
||||
# Partial get using indices
|
||||
self.assertTrue(sess.run([key_xf, get_xf],
|
||||
feed_dict={gi: 0}) == [0, [1, 2]])
|
||||
|
||||
# Still some of key 0 left
|
||||
self.assertTrue(sess.run([size, isize]) == [1, 0])
|
||||
|
||||
# Partial get of remaining index
|
||||
self.assertTrue(sess.run([key_v, get_v],
|
||||
feed_dict={gi: 0}) == [0, [3]])
|
||||
|
||||
# All gone
|
||||
self.assertTrue(sess.run([size, isize]) == [0, 0])
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -27,7 +27,7 @@ from tensorflow.python.platform import test
|
||||
class StageTest(test.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
@ -36,13 +36,17 @@ class StageTest(test.TestCase):
|
||||
stage = stager.put([v])
|
||||
y = stager.get()
|
||||
y = math_ops.reduce_max(math_ops.matmul(y, y))
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1})
|
||||
for i in range(10):
|
||||
_, yval = sess.run([stage, y], feed_dict={x: i})
|
||||
self.assertAllClose(4 * (i - 1) * (i - 1) * 128, yval, rtol=1e-4)
|
||||
|
||||
def testMultiple(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
@ -51,6 +55,10 @@ class StageTest(test.TestCase):
|
||||
stage = stager.put([x, v])
|
||||
z, y = stager.get()
|
||||
y = math_ops.reduce_max(z * math_ops.matmul(y, y))
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1})
|
||||
for i in range(10):
|
||||
_, yval = sess.run([stage, y], feed_dict={x: i})
|
||||
@ -58,7 +66,7 @@ class StageTest(test.TestCase):
|
||||
4 * (i - 1) * (i - 1) * (i - 1) * 128, yval, rtol=1e-4)
|
||||
|
||||
def testDictionary(self):
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
@ -72,6 +80,10 @@ class StageTest(test.TestCase):
|
||||
z = ret['x']
|
||||
y = ret['v']
|
||||
y = math_ops.reduce_max(z * math_ops.matmul(y, y))
|
||||
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1})
|
||||
for i in range(10):
|
||||
_, yval = sess.run([stage, y], feed_dict={x: i})
|
||||
@ -81,29 +93,35 @@ class StageTest(test.TestCase):
|
||||
def testColocation(self):
|
||||
gpu_dev = test.gpu_device_name()
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(gpu_dev):
|
||||
stager = data_flow_ops.StagingArea([dtypes.float32])
|
||||
y = stager.put([v])
|
||||
self.assertEqual(y.device, '/device:GPU:0' if gpu_dev
|
||||
else gpu_dev)
|
||||
with ops.device('/cpu:0'):
|
||||
x = stager.get()
|
||||
self.assertEqual(x.device, '/device:CPU:0')
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32)
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(gpu_dev):
|
||||
stager = data_flow_ops.StagingArea([dtypes.float32])
|
||||
y = stager.put([v])
|
||||
self.assertEqual(y.device, '/device:GPU:0' if gpu_dev
|
||||
else gpu_dev)
|
||||
with ops.device('/cpu:0'):
|
||||
x = stager.get()
|
||||
self.assertEqual(x.device, '/device:CPU:0')
|
||||
|
||||
G.finalize()
|
||||
|
||||
def testPeek(self):
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
p = array_ops.placeholder(dtypes.int32, name='p')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea([dtypes.int32, ], shapes=[[]])
|
||||
stage = stager.put([x])
|
||||
peek = stager.peek(p)
|
||||
ret = stager.get()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
p = array_ops.placeholder(dtypes.int32, name='p')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea([dtypes.int32, ], shapes=[[]])
|
||||
stage = stager.put([x])
|
||||
peek = stager.peek(p)
|
||||
ret = stager.get()
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
for i in range(10):
|
||||
sess.run(stage, feed_dict={x:i})
|
||||
|
||||
@ -111,20 +129,23 @@ class StageTest(test.TestCase):
|
||||
self.assertTrue(sess.run(peek, feed_dict={p:i}) == i)
|
||||
|
||||
def testSizeAndClear(self):
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32, name='x')
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea(
|
||||
[dtypes.float32, dtypes.float32],
|
||||
shapes=[[], [128, 128]],
|
||||
names=['x', 'v'])
|
||||
stage = stager.put({'x': x, 'v': v})
|
||||
ret = stager.get()
|
||||
size = stager.size()
|
||||
clear = stager.clear()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.float32, name='x')
|
||||
v = 2. * (array_ops.zeros([128, 128]) + x)
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea(
|
||||
[dtypes.float32, dtypes.float32],
|
||||
shapes=[[], [128, 128]],
|
||||
names=['x', 'v'])
|
||||
stage = stager.put({'x': x, 'v': v})
|
||||
ret = stager.get()
|
||||
size = stager.size()
|
||||
clear = stager.clear()
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
G.finalize()
|
||||
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
sess.run(stage, feed_dict={x: -1})
|
||||
self.assertEqual(sess.run(size), 1)
|
||||
sess.run(stage, feed_dict={x: -1})
|
||||
@ -135,14 +156,17 @@ class StageTest(test.TestCase):
|
||||
def testCapacity(self):
|
||||
capacity = 3
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea([dtypes.int32, ],
|
||||
capacity=capacity, shapes=[[]])
|
||||
stage = stager.put([x])
|
||||
ret = stager.get()
|
||||
size = stager.size()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.int32, name='x')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea([dtypes.int32, ],
|
||||
capacity=capacity, shapes=[[]])
|
||||
stage = stager.put([x])
|
||||
ret = stager.get()
|
||||
size = stager.size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
from six.moves import queue as Queue
|
||||
import threading
|
||||
@ -151,7 +175,7 @@ class StageTest(test.TestCase):
|
||||
n = 5
|
||||
missed = 0
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# Stage data in a separate thread which will block
|
||||
# when it hits the staging area's capacity and thus
|
||||
# not fill the queue with n tokens
|
||||
@ -193,14 +217,17 @@ class StageTest(test.TestCase):
|
||||
chunk = 200*1024 # 256K
|
||||
capacity = memory_limit // chunk
|
||||
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.uint8, name='x')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea([dtypes.uint8, ],
|
||||
memory_limit=memory_limit, shapes=[[]])
|
||||
stage = stager.put([x])
|
||||
ret = stager.get()
|
||||
size = stager.size()
|
||||
with ops.Graph().as_default() as G:
|
||||
with ops.device('/cpu:0'):
|
||||
x = array_ops.placeholder(dtypes.uint8, name='x')
|
||||
with ops.device(test.gpu_device_name()):
|
||||
stager = data_flow_ops.StagingArea([dtypes.uint8, ],
|
||||
memory_limit=memory_limit, shapes=[[]])
|
||||
stage = stager.put([x])
|
||||
ret = stager.get()
|
||||
size = stager.size()
|
||||
|
||||
G.finalize()
|
||||
|
||||
from six.moves import queue as Queue
|
||||
import threading
|
||||
@ -210,7 +237,7 @@ class StageTest(test.TestCase):
|
||||
n = 5
|
||||
missed = 0
|
||||
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
with self.test_session(use_gpu=True, graph=G) as sess:
|
||||
# Stage data in a separate thread which will block
|
||||
# when it hits the staging area's capacity and thus
|
||||
# not fill the queue with n tokens
|
||||
|
@ -1506,7 +1506,7 @@ class BaseStagingArea(object):
|
||||
|
||||
return tensors
|
||||
|
||||
def _get_return_value(self, tensors):
|
||||
def _get_return_value(self, tensors, indices):
|
||||
"""Return the value to return from a get op.
|
||||
|
||||
If the staging area has names, return a dictionary with the
|
||||
@ -1515,6 +1515,7 @@ class BaseStagingArea(object):
|
||||
|
||||
Args:
|
||||
tensors: List of tensors from the get op.
|
||||
indices: Indices of associated names and shapes
|
||||
|
||||
Returns:
|
||||
A single tensor, a list of tensors, or a dictionary
|
||||
@ -1524,13 +1525,13 @@ class BaseStagingArea(object):
|
||||
tensors = self._create_device_transfers(tensors)
|
||||
|
||||
# Sets shape
|
||||
for output, shape in zip(tensors, self._shapes):
|
||||
output.set_shape(shape)
|
||||
for output, i in zip(tensors, indices):
|
||||
output.set_shape(self._shapes[i])
|
||||
|
||||
if self._names:
|
||||
# The returned values in `tensors` are in the same order as
|
||||
# the names in `self._names`.
|
||||
return {n: tensors[i] for i, n in enumerate(self._names)}
|
||||
return {self._names[i]: t for t, i in zip(tensors, indices)}
|
||||
elif len(tensors) == 1:
|
||||
return tensors[0]
|
||||
else:
|
||||
@ -1646,7 +1647,8 @@ class StagingArea(BaseStagingArea):
|
||||
self._scope_vals(values)) as scope:
|
||||
|
||||
# Hard-code indices for this staging area
|
||||
indices = range(len(values)) if isinstance(values, (list, tuple)) else None
|
||||
indices = (list(six.moves.range(len(values)))
|
||||
if isinstance(values, (list, tuple)) else None)
|
||||
vals, _ = self._check_put_dtypes(values, indices)
|
||||
|
||||
with ops.colocate_with(self._coloc_op):
|
||||
@ -1660,7 +1662,8 @@ class StagingArea(BaseStagingArea):
|
||||
with ops.colocate_with(self._coloc_op):
|
||||
ret = get_fn()
|
||||
|
||||
return self._get_return_value(ret)
|
||||
indices = list(six.moves.range(len(self._dtypes))) # Hard coded
|
||||
return self._get_return_value(ret, indices)
|
||||
|
||||
def get(self, name=None):
|
||||
"""Gets one element from this staging area.
|
||||
@ -1802,11 +1805,16 @@ class MapStagingArea(BaseStagingArea):
|
||||
All get() and peek() commands block if the requested
|
||||
(key, value) pair is not present in the staging area.
|
||||
|
||||
Incomplete puts are supported and will be placed in an incomplete
|
||||
hash until such time as all values associated with the key have
|
||||
Partial puts are supported and will be placed in an incomplete
|
||||
map until such time as all values associated with the key have
|
||||
been inserted. Once completed, this (key, value) pair will be
|
||||
inserted into the main data structure. Data in the incomplete set
|
||||
inserted into the map. Data in the incomplete map
|
||||
counts towards the memory limit, but not towards capacity limit.
|
||||
|
||||
Partial gets from the map are also supported.
|
||||
This removes the partially requested tensors from the entry,
|
||||
but the entry is only removed from the map once all tensors
|
||||
associated with it are removed.
|
||||
"""
|
||||
|
||||
def __init__(self, dtypes, shapes=None, names=None, shared_name=None,
|
||||
@ -1901,7 +1909,38 @@ class MapStagingArea(BaseStagingArea):
|
||||
memory_limit=self._memory_limit)
|
||||
return op
|
||||
|
||||
def peek(self, key, name=None):
|
||||
def _get_indices_and_dtypes(self, indices=None):
|
||||
if indices is None:
|
||||
indices = list(six.moves.range(len(self._dtypes)))
|
||||
|
||||
if not isinstance(indices, (tuple, list)):
|
||||
raise TypeError("Invalid indices type '%s'" % type(indices))
|
||||
|
||||
if len(indices) == 0:
|
||||
raise ValueError("Empty indices")
|
||||
|
||||
if all(isinstance(i, str) for i in indices):
|
||||
if self._names is None:
|
||||
raise ValueError("String indices provided '%s', but this Staging Area "
|
||||
"was not created with names." % indices)
|
||||
|
||||
try:
|
||||
indices = [self._names.index(n) for n in indices]
|
||||
except ValueError:
|
||||
raise ValueError("Named index '%s' not in "
|
||||
"Staging Area names '%s'" % (n, self._names))
|
||||
elif all(isinstance(i, int) for i in indices):
|
||||
pass
|
||||
else:
|
||||
raise TypeError("Mixed types in indices '%s'. "
|
||||
"May only be str or int" % indices)
|
||||
|
||||
dtypes = [self._dtypes[i] for i in indices]
|
||||
|
||||
return indices, dtypes
|
||||
|
||||
|
||||
def peek(self, key, indices=None, name=None):
|
||||
"""
|
||||
Peeks at staging area data associated with the key.
|
||||
|
||||
@ -1910,6 +1949,10 @@ class MapStagingArea(BaseStagingArea):
|
||||
|
||||
Args:
|
||||
key: Key associated with the required data
|
||||
indices: Partial list of tensors to retrieve (optional).
|
||||
A list of integer or string indices.
|
||||
String indices are only valid if the Staging Area
|
||||
has names associated with it.
|
||||
name: A name for the operation (optional)
|
||||
|
||||
Returns:
|
||||
@ -1919,16 +1962,19 @@ class MapStagingArea(BaseStagingArea):
|
||||
if name is None:
|
||||
name = "%s_pop" % self._name
|
||||
|
||||
indices, dtypes = self._get_indices_and_dtypes(indices)
|
||||
|
||||
with ops.colocate_with(self._coloc_op):
|
||||
result = self._peek_fn(key, shared_name=self._name,
|
||||
dtypes=self._dtypes,
|
||||
indices=indices,
|
||||
dtypes=dtypes,
|
||||
name=name,
|
||||
capacity=self._capacity,
|
||||
memory_limit=self._memory_limit)
|
||||
|
||||
return self._get_return_value(result)
|
||||
return self._get_return_value(result, indices)
|
||||
|
||||
def get(self, key=None, name=None):
|
||||
def get(self, key=None, indices=None, name=None):
|
||||
"""
|
||||
If the key is provided, the associated (key, value)
|
||||
is returned from the staging area. If the key is not
|
||||
@ -1944,18 +1990,21 @@ class MapStagingArea(BaseStagingArea):
|
||||
|
||||
Args:
|
||||
key: Key associated with the required data (Optional)
|
||||
indices: Partial list of tensors to retrieve (optional).
|
||||
A list of integer or string indices.
|
||||
String indices are only valid if the Staging Area
|
||||
has names associated with it.
|
||||
name: A name for the operation (optional)
|
||||
|
||||
Returns:
|
||||
The created op
|
||||
"""
|
||||
if key is None:
|
||||
return self._popitem(name)
|
||||
return self._popitem(indices=indices, name=name)
|
||||
else:
|
||||
return self._pop(key, name)
|
||||
return self._pop(key, indices=indices, name=name)
|
||||
|
||||
|
||||
def _pop(self, key, name=None):
|
||||
def _pop(self, key, indices=None, name=None):
|
||||
"""
|
||||
Remove and return the associated (key, value)
|
||||
is returned from the staging area. If the key is not
|
||||
@ -1964,6 +2013,10 @@ class MapStagingArea(BaseStagingArea):
|
||||
|
||||
Args:
|
||||
key: Key associated with the required data
|
||||
indices: Partial list of tensors to retrieve (optional).
|
||||
A list of integer or string indices.
|
||||
String indices are only valid if the Staging Area
|
||||
has names associated with it.
|
||||
name: A name for the operation (optional)
|
||||
|
||||
Returns:
|
||||
@ -1972,16 +2025,19 @@ class MapStagingArea(BaseStagingArea):
|
||||
if name is None:
|
||||
name = "%s_get" % self._name
|
||||
|
||||
indices, dtypes = self._get_indices_and_dtypes(indices)
|
||||
|
||||
with ops.colocate_with(self._coloc_op):
|
||||
result = self._pop_fn(key, shared_name=self._name,
|
||||
dtypes=self._dtypes,
|
||||
indices=indices,
|
||||
dtypes=dtypes,
|
||||
name=name,
|
||||
capacity=self._capacity,
|
||||
memory_limit=self._memory_limit)
|
||||
|
||||
return key, self._get_return_value(result)
|
||||
return key, self._get_return_value(result, indices)
|
||||
|
||||
def _popitem(self, name=None):
|
||||
def _popitem(self, indices=None, name=None):
|
||||
"""
|
||||
If the staging area is ordered,
|
||||
the (key, value) with the smallest key will be returned.
|
||||
@ -1992,6 +2048,10 @@ class MapStagingArea(BaseStagingArea):
|
||||
|
||||
Args:
|
||||
key: Key associated with the required data
|
||||
indices: Partial list of tensors to retrieve (optional).
|
||||
A list of integer or string indices.
|
||||
String indices are only valid if the Staging Area
|
||||
has names associated with it.
|
||||
name: A name for the operation (optional)
|
||||
|
||||
Returns:
|
||||
@ -2000,9 +2060,12 @@ class MapStagingArea(BaseStagingArea):
|
||||
if name is None:
|
||||
name = "%s_get_nokey" % self._name
|
||||
|
||||
indices, dtypes = self._get_indices_and_dtypes(indices)
|
||||
|
||||
with ops.colocate_with(self._coloc_op):
|
||||
key, result = self._popitem_fn(shared_name=self._name,
|
||||
dtypes=self._dtypes,
|
||||
indices=indices,
|
||||
dtypes=dtypes,
|
||||
name=name,
|
||||
capacity=self._capacity,
|
||||
memory_limit=self._memory_limit)
|
||||
@ -2010,7 +2073,7 @@ class MapStagingArea(BaseStagingArea):
|
||||
# Separate keys and results out from
|
||||
# underlying namedtuple
|
||||
key = self._create_device_transfers(key)[0]
|
||||
result = self._get_return_value(result)
|
||||
result = self._get_return_value(result, indices)
|
||||
|
||||
return key, result
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user