[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
This commit is contained in:
parent
12e164d1e7
commit
072fcb995a
@ -0,0 +1,58 @@
|
||||
op {
|
||||
graph_op_name: "ExperimentalNumaMapAndBatchDataset"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "input_dataset"
|
||||
description: <<END
|
||||
A variant tensor representing the input dataset.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "other_arguments"
|
||||
description: <<END
|
||||
A list of tensors, typically values that were captured when building a closure
|
||||
for `f`.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "batch_size"
|
||||
description: <<END
|
||||
A scalar representing the number of elements to accumulate in a
|
||||
batch. It determines the number of concurrent invocations of `f` that process
|
||||
elements from `input_dataset` in parallel.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "num_parallel_calls"
|
||||
description: <<END
|
||||
A scalar representing the maximum number of parallel invocations of `f`.
|
||||
Applying `f` to consecutive input elements in parallel has
|
||||
the potential to improve input pipeline throughput.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "drop_remainder"
|
||||
description: <<END
|
||||
A scalar representing whether the last batch should be dropped in case its size
|
||||
is smaller than desired.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "f"
|
||||
description: <<END
|
||||
A function to apply to the outputs of `input_dataset`.
|
||||
END
|
||||
}
|
||||
summary: "Creates a dataset that fuses mapping with batching."
|
||||
description: <<END
|
||||
Creates a dataset that applies `f` to the outputs of `input_dataset` and then
|
||||
batches `batch_size` of them.
|
||||
|
||||
Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
|
||||
to `num_parallel_calls` copies of `f` in parallel.
|
||||
|
||||
Unlike "MapAndBatchDatasetV2", this dataset uses a NUMA-aware thread scheduling
|
||||
policy. Because it uses the single-threaded executor, it only supports the
|
||||
function-based control flow ops.
|
||||
END
|
||||
}
|
@ -335,7 +335,7 @@ class Model {
|
||||
if (name_ == "Map") {
|
||||
return Type::MAP;
|
||||
}
|
||||
if (name_ == "MapAndBatch") {
|
||||
if (name_ == "MapAndBatch" || name_ == "NumaMapAndBatch") {
|
||||
return Type::MAP_AND_BATCH;
|
||||
}
|
||||
if (name_ == "PaddedBatch") {
|
||||
|
@ -414,6 +414,40 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "map_and_batch_numa_aware_replacement",
|
||||
srcs = ["map_and_batch_numa_aware_replacement.cc"],
|
||||
hdrs = ["map_and_batch_numa_aware_replacement.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_utils",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
|
||||
] + tf_protos_all(),
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "map_and_batch_numa_aware_replacement_test",
|
||||
srcs = ["map_and_batch_numa_aware_replacement_test.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_test_utils",
|
||||
":graph_utils",
|
||||
":map_and_batch_numa_aware_replacement",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "noop_elimination",
|
||||
srcs = ["noop_elimination.cc"],
|
||||
@ -490,6 +524,7 @@ cc_library(
|
||||
":hoist_random_uniform",
|
||||
":latency_all_edges",
|
||||
":map_and_batch_fusion",
|
||||
":map_and_batch_numa_aware_replacement",
|
||||
":map_and_filter_fusion",
|
||||
":map_fusion",
|
||||
":map_parallelization",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -44,6 +45,21 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
|
||||
{"output_types", gtl::ArraySlice<TensorShape>{}}});
|
||||
}
|
||||
|
||||
// Builds a "MapAndBatchDatasetV2" test node called `name`.
//
// The node's inputs are, in order: `input_node_name` (the upstream dataset),
// `batch_size_node_name`, `num_parallel_calls_node_name` and
// `drop_remainder_node_name`. `Targuments` is left empty, so the node carries
// no `other_arguments` inputs. The mapped function is stored in the `f` attr
// as a reference to `function_name`.
NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
                            StringPiece batch_size_node_name,
                            StringPiece num_parallel_calls_node_name,
                            StringPiece drop_remainder_node_name,
                            StringPiece function_name) {
  return test::function::NDef(
      name, "MapAndBatchDatasetV2",
      // No placeholder entry for `other_arguments`: with an empty
      // `Targuments` list that input contributes zero tensors.
      {string(input_node_name), string(batch_size_node_name),
       string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
      // The function attr of the map-and-batch ops is named "f"; "predicate"
      // is the attr used by the filter helper this was modelled on.
      {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
       {"Targuments", {}},
       {"output_shapes", gtl::ArraySlice<TensorShape>{}},
       // `output_types` is a list(type) attr, so it holds DataType values.
       {"output_types", gtl::ArraySlice<DataType>{}}});
}
|
||||
|
||||
} // end namespace graph_tests_utils
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -29,6 +29,12 @@ NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
|
||||
NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
|
||||
StringPiece function_name = "IsZero");
|
||||
|
||||
NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
|
||||
StringPiece batch_size_node_name,
|
||||
StringPiece num_parallel_calls_node_name,
|
||||
StringPiece drop_remainder_node_name,
|
||||
StringPiece function_name = "XTimesTwo");
|
||||
|
||||
} // end namespace graph_tests_utils
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -0,0 +1,62 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h"
|
||||
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/clusters/cluster.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// Returns a copy of `node` turned into its NUMA-aware counterpart.
//
// The copy keeps every field of `node` except:
//   * its name, which is reassigned by graph_utils::SetUniqueGraphNodeName
//     using the "map_and_batch_numa_aware" prefix and the graph behind
//     `graph`;
//   * its op, which becomes "ExperimentalNumaMapAndBatchDataset".
// The returned node is not added to `graph`; that is left to the caller.
NodeDef MakeNumaAware(const NodeDef& node, MutableGraphView* graph) {
  NodeDef replacement(node);
  graph_utils::SetUniqueGraphNodeName("map_and_batch_numa_aware",
                                      graph->GetGraph(), &replacement);
  replacement.set_op("ExperimentalNumaMapAndBatchDataset");
  return replacement;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Replaces every "MapAndBatchDatasetV2" node of `item.graph` with an
// "ExperimentalNumaMapAndBatchDataset" node.
//
// `output` starts out as a copy of the input graph. For each matching node a
// NUMA-aware clone is added to `output`, the original is swapped for the clone
// via ReplaceInput, and the originals are deleted in a single batch after the
// scan. `cluster` is not used. Always returns OK.
Status MapAndBatchNumaAwareReplacement::Optimize(Cluster* cluster,
                                                 const GrapplerItem& item,
                                                 GraphDef* output) {
  *output = item.graph;
  MutableGraphView graph(output);
  std::set<string> nodes_to_delete;

  // The loop walks the input graph rather than `output`, so clones appended
  // to `output` during the rewrite are never visited.
  for (const NodeDef& candidate : item.graph.node()) {
    if (candidate.op() == "MapAndBatchDatasetV2") {
      auto* replacement = graph.AddNode(MakeNumaAware(candidate, &graph));
      graph.ReplaceInput(candidate, *replacement);
      nodes_to_delete.insert(candidate.name());
    }
  }

  graph.DeleteNodes(nodes_to_delete);
  return Status::OK();
}
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER_AS(MapAndBatchNumaAwareReplacement,
|
||||
"map_and_batch_numa_aware_replacement");
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -0,0 +1,48 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class MapAndBatchNumaAwareReplacement : public CustomGraphOptimizer {
|
||||
public:
|
||||
MapAndBatchNumaAwareReplacement() = default;
|
||||
~MapAndBatchNumaAwareReplacement() override = default;
|
||||
|
||||
string name() const override {
|
||||
return "map_and_batch_numa_aware_replacement";
|
||||
}
|
||||
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimize_output, double result) override {}
|
||||
};
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_AND_BATCH_NUMA_AWARE_REPLACEMENT_H_
|
@ -0,0 +1,112 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/data/map_and_batch_numa_aware_replacement.h"
|
||||
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
|
||||
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
// A graph whose only dataset transformation is a MapAndBatchDatasetV2 node
// must come out of the optimizer with that node gone and an
// ExperimentalNumaMapAndBatchDataset node in its place.
TEST(MapAndBatchNumaAwareReplacementTest, ReplaceSimple) {
  using test::function::NDef;

  // Pipeline: RangeDataset(start, stop, step) -> map_and_batch.
  GrapplerItem item;
  item.graph = test::function::GDef(
      {
          NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
          NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
          NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
          NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
          NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}),
          NDef("num_parallel_calls", "Const", {},
               {{"value", 5}, {"dtype", DT_INT32}}),
          NDef("drop_remainder", "Const", {},
               {{"value", 0}, {"dtype", DT_BOOL}}),
          graph_tests_utils::MakeMapAndBatchNode(
              "map_and_batch", "range", "batch_size", "num_parallel_calls",
              "drop_remainder"),
      },
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  GraphDef optimized;
  MapAndBatchNumaAwareReplacement optimizer;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &optimized));

  // The original node is gone, both by name and by op type ...
  EXPECT_FALSE(
      graph_utils::ContainsGraphNodeWithName("map_and_batch", optimized));
  EXPECT_FALSE(
      graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", optimized));
  // ... and its NUMA-aware replacement is present.
  EXPECT_TRUE(graph_utils::ContainsNodeWithOp(
      "ExperimentalNumaMapAndBatchDataset", optimized));
}
|
||||
|
||||
// When the MapAndBatchDatasetV2 node has a consumer (here a CacheDataset),
// the optimizer must replace the node and rewire the consumer so that it
// reads from the ExperimentalNumaMapAndBatchDataset replacement, while the
// replacement keeps reading from the original upstream dataset.
TEST(MapAndBatchNumaAwareReplacementTest, ReplaceWithExtraChild) {
  using test::function::NDef;
  GrapplerItem item;
  // Pipeline: RangeDataset -> map_and_batch -> cache.
  item.graph = test::function::GDef(
      {
          NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
          NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
          NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
          NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
          NDef("batch_size", "Const", {}, {{"value", 3}, {"dtype", DT_INT32}}),
          NDef("num_parallel_calls", "Const", {},
               {{"value", 5}, {"dtype", DT_INT32}}),
          NDef("drop_remainder", "Const", {},
               {{"value", 0}, {"dtype", DT_BOOL}}),
          graph_tests_utils::MakeMapAndBatchNode(
              "map_and_batch", "range", "batch_size", "num_parallel_calls",
              "drop_remainder"),
          NDef("cache", "CacheDataset", {"map_and_batch"}, {}),
      },
      // FunctionLib
      {
          test::function::XTimesTwo(),
      });

  MapAndBatchNumaAwareReplacement optimizer;
  GraphDef output;
  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));

  EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map_and_batch", output));
  EXPECT_FALSE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
  // These two are hard requirements: the lookups below index into `output`
  // with the positions of these nodes, so stop here if either is missing.
  ASSERT_TRUE(graph_utils::ContainsNodeWithOp(
      "ExperimentalNumaMapAndBatchDataset", output));
  ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));

  // The replacement still consumes the range dataset.
  int numa_map_and_batch_component_id = graph_utils::FindGraphNodeWithOp(
      "ExperimentalNumaMapAndBatchDataset", output);
  const auto& numa_map_and_batch_component =
      output.node(numa_map_and_batch_component_id);
  EXPECT_EQ(numa_map_and_batch_component.input(0), "range");

  // The downstream cache node now consumes the replacement.
  int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
  const auto& cache_node = output.node(cache_id);
  EXPECT_EQ(cache_node.input(0), numa_map_and_batch_component.name());
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -102,6 +102,22 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "numa_map_and_batch_dataset_op",
|
||||
srcs = ["numa_map_and_batch_dataset_op.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/kernels:inplace_ops",
|
||||
"//tensorflow/core/kernels/data:captured_function",
|
||||
"//tensorflow/core/kernels/data:dataset",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "unique_dataset_op",
|
||||
srcs = ["unique_dataset_op.cc"],
|
||||
@ -132,6 +148,7 @@ tf_kernel_library(
|
||||
":ignore_errors_dataset_op",
|
||||
":indexed_dataset",
|
||||
":lmdb_dataset_op",
|
||||
":numa_map_and_batch_dataset_op",
|
||||
":prefetching_kernels",
|
||||
":threadpool_dataset_op",
|
||||
":unique_dataset_op",
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -324,6 +324,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
private:
|
||||
// BatchResult encapsulates the output batch, as well as ancillary
|
||||
// metadata required to execute the fused map-and-batch operation.
|
||||
struct BatchResult {
|
||||
explicit BatchResult(int64 batch_size) {
|
||||
end_of_input = false;
|
||||
@ -331,11 +333,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
num_elements = 0;
|
||||
output_allocated = false;
|
||||
status = Status::OK();
|
||||
status_offset = -1;
|
||||
}
|
||||
|
||||
void UpdateStatus(const Status& s) {
|
||||
mutex_lock l(mu);
|
||||
status.Update(s);
|
||||
// UpdateStatus updates the batch's aggregate Status.
|
||||
//
|
||||
// In order to ensure that exactly the first non-OK status is returned
|
||||
// (required to make the behavior observably identical to a
|
||||
// sequential execution of map followed by batch), we must also keep
|
||||
// track of the offset into the batch that produced `s`.
|
||||
void UpdateStatus(const Status& s, int64 offset) {
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
mutex_lock l(mu);
|
||||
if (status.ok() || offset < status_offset) {
|
||||
status = s;
|
||||
status_offset = offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mutex mu;
|
||||
@ -344,6 +358,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<Tensor> output;
|
||||
bool output_allocated GUARDED_BY(mu);
|
||||
Status status GUARDED_BY(mu);
|
||||
int64 status_offset GUARDED_BY(mu);
|
||||
// Counts the number of outstanding calls for this batch.
|
||||
int64 num_calls; // access guarded by owner's mutex
|
||||
};
|
||||
@ -379,7 +394,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::shared_ptr<std::vector<Tensor>> return_values =
|
||||
std::make_shared<std::vector<Tensor>>();
|
||||
auto done = [this, ctx, result, return_values, offset](Status status) {
|
||||
result->UpdateStatus(status);
|
||||
result->UpdateStatus(status, offset);
|
||||
if (status.ok()) {
|
||||
EnsureOutputAllocated(ctx, result, return_values);
|
||||
for (size_t i = 0; i < return_values->size(); ++i) {
|
||||
@ -389,11 +404,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
(batch->NumElements() / batch->dim_size(0))) {
|
||||
TensorShape batch_shape = batch->shape();
|
||||
batch_shape.RemoveDim(0);
|
||||
result->UpdateStatus(errors::InvalidArgument(
|
||||
"Cannot add tensor to the batch: number of elements does "
|
||||
"not match. Shapes are: [tensor]: ",
|
||||
tensor.shape().DebugString(),
|
||||
", [batch]: ", batch_shape.DebugString()));
|
||||
result->UpdateStatus(
|
||||
errors::InvalidArgument(
|
||||
"Cannot add tensor to the batch: number of elements "
|
||||
"does "
|
||||
"not match. Shapes are: [tensor]: ",
|
||||
tensor.shape().DebugString(),
|
||||
", [batch]: ", batch_shape.DebugString()),
|
||||
offset);
|
||||
break;
|
||||
}
|
||||
// TODO(mrry): Add a version of DoParallelConcat that allows us to
|
||||
@ -402,7 +420,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
Status copy_status = ::tensorflow::functor::DoParallelConcat(
|
||||
*dataset()->device_, tensor, offset, batch);
|
||||
if (!copy_status.ok()) {
|
||||
result->UpdateStatus(copy_status);
|
||||
result->UpdateStatus(copy_status, offset);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -138,6 +138,32 @@ REGISTER_OP("ExperimentalAssertNextDataset")
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalNumaMapAndBatchDataset")
|
||||
.Input("input_dataset: variant")
|
||||
.Input("other_arguments: Targuments")
|
||||
.Input("batch_size: int64")
|
||||
.Input("num_parallel_calls: int64")
|
||||
.Input("drop_remainder: bool")
|
||||
.Output("handle: variant")
|
||||
.Attr("f: func")
|
||||
.Attr("Targuments: list(type) >= 0")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||
// Use index from the end to retrieve the Input shapes,
|
||||
// so as to avoid guessing the length of "other_arguments".
|
||||
// batch_size, num_parallel_calls, and drop_remainder are 0-D scalars.
|
||||
shape_inference::ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 3), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 2), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
|
||||
|
||||
return shape_inference::ScalarShape(c);
|
||||
});
|
||||
|
||||
REGISTER_OP("ExperimentalLMDBDataset")
|
||||
.Input("filenames: string")
|
||||
.Output("handle: variant")
|
||||
|
@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -38,12 +39,17 @@ from tensorflow.python.platform import test
|
||||
class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Default", None, None),
|
||||
("SequentialCalls", 1, None),
|
||||
("ParallelCalls", 2, None),
|
||||
("ParallelBatches", None, 10),
|
||||
("Default", None, None, False),
|
||||
("SequentialCalls", 1, None, False),
|
||||
("ParallelCalls", 2, None, False),
|
||||
("ParallelBatches", None, 10, False),
|
||||
("DefaultNUMA", None, None, True),
|
||||
("SequentialCallsNUMA", 1, None, True),
|
||||
("ParallelCallsNUMA", 2, None, True),
|
||||
("ParallelBatchesNUMA", None, 10, True),
|
||||
)
|
||||
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
|
||||
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches,
|
||||
numa_aware):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
# The pipeline is TensorSliceDataset ->
|
||||
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
|
||||
@ -57,14 +63,20 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def _map_fn(x, y, z):
|
||||
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
|
||||
|
||||
iterator = (
|
||||
dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
|
||||
batching.map_and_batch(
|
||||
map_func=_map_fn,
|
||||
batch_size=batch_size,
|
||||
num_parallel_calls=num_parallel_calls,
|
||||
num_parallel_batches=num_parallel_batches))
|
||||
.make_initializable_iterator())
|
||||
num_parallel_batches=num_parallel_batches)))
|
||||
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -115,16 +127,25 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Even", False),
|
||||
("Uneven", True),
|
||||
("Even", False, False),
|
||||
("Uneven", True, False),
|
||||
("EvenNUMA", False, True),
|
||||
("UnevenNUMA", True, True),
|
||||
)
|
||||
def testMapAndBatchPartialBatch(self, drop_remainder):
|
||||
iterator = (
|
||||
def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware):
|
||||
dataset = (
|
||||
dataset_ops.Dataset.range(10).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: array_ops.reshape(x * x, [1]),
|
||||
batch_size=4,
|
||||
drop_remainder=drop_remainder)).make_one_shot_iterator())
|
||||
drop_remainder=drop_remainder)))
|
||||
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
|
||||
if drop_remainder:
|
||||
self.assertEqual([4, 1], iterator.output_shapes.as_list())
|
||||
else:
|
||||
@ -138,11 +159,21 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testMapAndBatchYieldsPartialBatch(self):
|
||||
iterator = (dataset_ops.Dataset.range(10)
|
||||
.apply(batching.map_and_batch(
|
||||
lambda x: array_ops.reshape(x * x, [1]), 4))
|
||||
.make_one_shot_iterator())
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchYieldsPartialBatch(self, numa_aware):
|
||||
dataset = (
|
||||
dataset_ops.Dataset.range(10).apply(
|
||||
batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]), 4)))
|
||||
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
self.assertEqual([None, 1], iterator.output_shapes.as_list())
|
||||
next_element = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
@ -152,10 +183,19 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testMapAndBatchParallelGetNext(self):
|
||||
iterator = (dataset_ops.Dataset.range(50000)
|
||||
.apply(batching.map_and_batch(lambda x: x, batch_size=100))
|
||||
.make_one_shot_iterator())
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchParallelGetNext(self, numa_aware):
|
||||
dataset = dataset_ops.Dataset.range(50000).apply(
|
||||
batching.map_and_batch(lambda x: x, batch_size=100))
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
|
||||
elements = []
|
||||
for _ in range(100):
|
||||
elements.append(iterator.get_next())
|
||||
@ -165,17 +205,26 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
|
||||
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
|
||||
self.assertAllEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elements)
|
||||
|
||||
def testMapAndBatchParallelGetNextDropRemainder(self):
|
||||
iterator = (
|
||||
dataset_ops.Dataset.range(49999).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: x, batch_size=100, drop_remainder=True))
|
||||
.make_one_shot_iterator())
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware):
|
||||
dataset = dataset_ops.Dataset.range(49999).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: x, batch_size=100, drop_remainder=True))
|
||||
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
|
||||
elements = []
|
||||
for _ in range(100):
|
||||
elements.append(iterator.get_next())
|
||||
@ -185,19 +234,29 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
got.sort(key=lambda x: x[0])
|
||||
expected = []
|
||||
for j in range(100):
|
||||
expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
|
||||
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
|
||||
self.assertAllEqual(got, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(elements)
|
||||
|
||||
def testMapAndBatchSparse(self):
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchSparse(self, numa_aware):
|
||||
|
||||
def _sparse(i):
|
||||
return sparse_tensor.SparseTensorValue(
|
||||
indices=[[0]], values=(i * [1]), dense_shape=[1])
|
||||
|
||||
iterator = dataset_ops.Dataset.range(10).apply(
|
||||
batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
batching.map_and_batch(_sparse, 5))
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -214,21 +273,33 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testMapAndBatchFails(self):
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchFails(self, numa_aware):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
array_ops.check_numerics(
|
||||
constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
|
||||
batch_size = array_ops.placeholder(dtypes.int64, shape=[])
|
||||
iterator = (
|
||||
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
|
||||
.make_initializable_iterator())
|
||||
dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
|
||||
init_op = iterator.initializer
|
||||
with self.cached_session() as sess:
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
|
||||
sess.run(init_op, feed_dict={batch_size: 14})
|
||||
|
||||
def testMapAndBatchShapeMismatch(self):
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchShapeMismatch(self, numa_aware):
|
||||
"""Test a dataset that maps a TF function across its input elements."""
|
||||
|
||||
def generator():
|
||||
@ -240,9 +311,13 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=dtypes.int32)
|
||||
batch_size = 4
|
||||
iterator = (
|
||||
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
|
||||
.make_initializable_iterator())
|
||||
dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_initializable_iterator()
|
||||
|
||||
init_op = iterator.initializer
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
@ -251,7 +326,11 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
"number of elements does not match"):
|
||||
sess.run(get_next)
|
||||
|
||||
def testMapAndBatchImplicitDispose(self):
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchImplicitDispose(self, numa_aware):
|
||||
# Tests whether a map and batch dataset will be cleaned up correctly when
|
||||
# the pipeline does not run it until exhaustion.
|
||||
# The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
|
||||
@ -266,6 +345,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
|
||||
1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
|
||||
dataset = dataset.prefetch(5)
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
|
||||
@ -274,26 +357,38 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", 0),
|
||||
("2", 5),
|
||||
("3", 10),
|
||||
("4", 90),
|
||||
("5", 95),
|
||||
("6", 99),
|
||||
("1", 0, False),
|
||||
("2", 5, False),
|
||||
("3", 10, False),
|
||||
("4", 90, False),
|
||||
("5", 95, False),
|
||||
("6", 99, False),
|
||||
("1NUMA", 0, True),
|
||||
("2NUMA", 5, True),
|
||||
("3NUMA", 10, True),
|
||||
("4NUMA", 90, True),
|
||||
("5NUMA", 95, True),
|
||||
("6NUMA", 99, True),
|
||||
)
|
||||
def testMapAndBatchOutOfRangeError(self, threshold):
|
||||
def testMapAndBatchOutOfRangeError(self, threshold, numa_aware):
|
||||
|
||||
def raising_py_fn(i):
|
||||
if i >= threshold:
|
||||
if i == threshold:
|
||||
raise StopIteration()
|
||||
elif i > threshold:
|
||||
raise RuntimeError("Alternate error; you shouldn't see me! (i: %s)" % i)
|
||||
else:
|
||||
return i
|
||||
|
||||
iterator = (
|
||||
dataset_ops.Dataset.range(100).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
|
||||
batch_size=10)).make_one_shot_iterator())
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
batching.map_and_batch(
|
||||
lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
|
||||
batch_size=10))
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
@ -307,25 +402,42 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(get_next)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("1", False, dtypes.bool),
|
||||
("2", -42, dtypes.int8),
|
||||
("3", -42, dtypes.int16),
|
||||
("4", -42, dtypes.int32),
|
||||
("5", -42, dtypes.int64),
|
||||
("6", 42, dtypes.uint8),
|
||||
("7", 42, dtypes.uint16),
|
||||
("8", 42.0, dtypes.float16),
|
||||
("9", 42.0, dtypes.float32),
|
||||
("10", 42.0, dtypes.float64),
|
||||
("11", b"hello", dtypes.string),
|
||||
("1", False, dtypes.bool, False),
|
||||
("2", -42, dtypes.int8, False),
|
||||
("3", -42, dtypes.int16, False),
|
||||
("4", -42, dtypes.int32, False),
|
||||
("5", -42, dtypes.int64, False),
|
||||
("6", 42, dtypes.uint8, False),
|
||||
("7", 42, dtypes.uint16, False),
|
||||
("8", 42.0, dtypes.float16, False),
|
||||
("9", 42.0, dtypes.float32, False),
|
||||
("10", 42.0, dtypes.float64, False),
|
||||
("11", b"hello", dtypes.string, False),
|
||||
("1NUMA", False, dtypes.bool, True),
|
||||
("2NUMA", -42, dtypes.int8, True),
|
||||
("3NUMA", -42, dtypes.int16, True),
|
||||
("4NUMA", -42, dtypes.int32, True),
|
||||
("5NUMA", -42, dtypes.int64, True),
|
||||
("6NUMA", 42, dtypes.uint8, True),
|
||||
("7NUMA", 42, dtypes.uint16, True),
|
||||
("8NUMA", 42.0, dtypes.float16, True),
|
||||
("9NUMA", 42.0, dtypes.float32, True),
|
||||
("10NUMA", 42.0, dtypes.float64, True),
|
||||
("11NUMA", b"hello", dtypes.string, True),
|
||||
)
|
||||
def testMapAndBatchTypes(self, element, dtype):
|
||||
def testMapAndBatchTypes(self, element, dtype, numa_aware):
|
||||
|
||||
def gen():
|
||||
yield element
|
||||
|
||||
dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
|
||||
batching.map_and_batch(lambda x: x, batch_size=10))
|
||||
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
|
||||
get_next = dataset.make_one_shot_iterator().get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
@ -363,6 +475,40 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.run(iterator.initializer, feed_dict={captured_t: 42})
|
||||
self.assertAllEqual([42] * 10, sess.run(get_next))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Normal", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testMapAndBatchControlFlow(self, numa_aware):
|
||||
|
||||
def map_fn(x):
|
||||
previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = True
|
||||
return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
|
||||
control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
|
||||
return return_value
|
||||
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
batching.map_and_batch(map_fn, batch_size=10))
|
||||
if numa_aware:
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
with self.cached_session() as sess:
|
||||
for i in range(10):
|
||||
print("Case %d" % i)
|
||||
if i < 5:
|
||||
self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
|
||||
sess.run(get_next))
|
||||
else:
|
||||
self.assertAllEqual(
|
||||
[((i * 10) + j) * ((i * 10) + j) for j in range(10)],
|
||||
sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -161,6 +161,7 @@ py_test(
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
@ -199,6 +200,7 @@ py_test(
|
||||
deps = [
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/experimental/ops:optimization",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ModelDatasetTest(test_base.DatasetTestBase):
|
||||
class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
def testModelMap(self):
|
||||
k = 1024 * 1024
|
||||
@ -82,7 +83,11 @@ class ModelDatasetTest(test_base.DatasetTestBase):
|
||||
(np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
|
||||
np.max(deltas)))
|
||||
|
||||
def testModelMapAndBatch(self):
|
||||
@parameterized.named_parameters(
|
||||
("Default", False),
|
||||
("NUMA", True),
|
||||
)
|
||||
def testModelMapAndBatch(self, numa_aware):
|
||||
batch_size = 16
|
||||
k = 1024 * 1024
|
||||
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
|
||||
@ -95,6 +100,8 @@ class ModelDatasetTest(test_base.DatasetTestBase):
|
||||
batch_size=batch_size))
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_autotune = True
|
||||
if numa_aware:
|
||||
options.experimental_numa_aware = True
|
||||
iterator = dataset.with_options(options).make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
@ -59,6 +60,21 @@ class OptimizeDatasetTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testNumaAwareRewrite(self):
|
||||
dataset = dataset_ops.Dataset.range(10).apply(
|
||||
optimization.assert_next(["NumaMapAndBatch"])).apply(
|
||||
batching.map_and_batch(lambda x: x * x, 10))
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
dataset = dataset.with_options(options)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
get_next = iterator.get_next()
|
||||
|
||||
with self.cached_session() as sess:
|
||||
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
def testOptimizationStatefulFunction(self):
|
||||
dataset = dataset_ops.Dataset.range(10).map(
|
||||
lambda _: random_ops.random_uniform([])).batch(10)
|
||||
|
@ -306,6 +306,21 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "numa_map_and_batch_dataset_serialization_test",
|
||||
size = "medium",
|
||||
srcs = ["numa_map_and_batch_dataset_serialization_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["no_pip"],
|
||||
deps = [
|
||||
":dataset_serialization_test_base",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/data/experimental/ops:batching",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "map_dataset_serialization_test",
|
||||
size = "medium",
|
||||
|
@ -0,0 +1,95 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the MapAndBatchDataset serialization."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
|
||||
from tensorflow.python.data.experimental.ops import batching
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class MapAndBatchDatasetSerializationTest(
|
||||
dataset_serialization_test_base.DatasetSerializationTestBase):
|
||||
|
||||
def testNumParallelBatches(self):
|
||||
range_size = 11
|
||||
num_repeats = 2
|
||||
batch_size = 5
|
||||
total_outputs = range_size * num_repeats
|
||||
num_outputs_drop_remainder = total_outputs // batch_size
|
||||
num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
|
||||
num_parallel_batches = 2
|
||||
|
||||
def build_ds(range_start, drop_remainder=False):
|
||||
|
||||
def _map_fn(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
ds = dataset_ops.Dataset.range(
|
||||
range_start, range_start + range_size).repeat(num_repeats).apply(
|
||||
batching.map_and_batch(
|
||||
map_func=_map_fn,
|
||||
batch_size=batch_size,
|
||||
num_parallel_batches=num_parallel_batches,
|
||||
drop_remainder=drop_remainder))
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
return ds.with_options(options)
|
||||
|
||||
self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
|
||||
num_outputs_keep_remainder)
|
||||
self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
|
||||
num_outputs_drop_remainder)
|
||||
|
||||
def testNumParallelCalls(self):
|
||||
range_size = 11
|
||||
num_repeats = 2
|
||||
batch_size = 5
|
||||
total_outputs = range_size * num_repeats
|
||||
num_outputs_drop_remainder = total_outputs // batch_size
|
||||
num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
|
||||
num_parallel_calls = 7
|
||||
|
||||
def build_ds(range_start, drop_remainder=False):
|
||||
|
||||
def _map_fn(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
ds = dataset_ops.Dataset.range(
|
||||
range_start, range_start + range_size).repeat(num_repeats).apply(
|
||||
batching.map_and_batch(
|
||||
map_func=_map_fn,
|
||||
batch_size=batch_size,
|
||||
num_parallel_calls=num_parallel_calls,
|
||||
drop_remainder=drop_remainder))
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_numa_aware = True
|
||||
return ds.with_options(options)
|
||||
|
||||
self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
|
||||
num_outputs_keep_remainder)
|
||||
self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
|
||||
num_outputs_drop_remainder)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -122,6 +122,7 @@ py_library(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dataset_ops_gen",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:experimental_dataset_ops_gen",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
|
@ -1410,6 +1410,8 @@ class Options(object):
|
||||
"Whether to eliminate no-op transformations."),
|
||||
("experimental_shuffle_and_repeat_fusion", bool,
|
||||
"Whether to fuse shuffle and repeat transformations."),
|
||||
("experimental_numa_aware", bool,
|
||||
"Whether to use NUMA-aware operations."),
|
||||
]:
|
||||
|
||||
def _make_getter(name): # pylint: disable=no-self-argument
|
||||
@ -1458,6 +1460,9 @@ class Options(object):
|
||||
for exp_opt in experimental_optimizations:
|
||||
if getattr(self, "experimental_" + exp_opt):
|
||||
result.append(exp_opt)
|
||||
|
||||
if getattr(self, "experimental_numa_aware"):
|
||||
result.append("map_and_batch_numa_aware_replacement")
|
||||
return result
|
||||
|
||||
def merge(self, options):
|
||||
@ -1485,7 +1490,7 @@ class Options(object):
|
||||
"experimental_map_and_filter_fusion", "experimental_map_fusion",
|
||||
"experimental_map_parallelization", "experimental_map_vectorization",
|
||||
"experimental_noop_elimination",
|
||||
"experimental_shuffle_and_repeat_fusion"
|
||||
"experimental_shuffle_and_repeat_fusion", "experimental_numa_aware",
|
||||
]:
|
||||
this = getattr(result, name)
|
||||
that = getattr(other, name)
|
||||
|
@ -42,6 +42,10 @@ tf_class {
|
||||
name: "experimental_noop_elimination"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_numa_aware"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_shuffle_and_repeat_fusion"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -42,6 +42,10 @@ tf_class {
|
||||
name: "experimental_noop_elimination"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_numa_aware"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_shuffle_and_repeat_fusion"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user