[rollback] Support new pipelines in autosharding by including it in FILE autosharding policy
PiperOrigin-RevId: 336902256
Change-Id: I977591868b46405e57612251777fdab4206c4d71
parent d9ad5ce61b
commit 6df9f5a51d
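For context, the code removed below taught the FILE autosharding policy to recognize a "dataset of reader datasets" pipeline: one reader dataset per file glob, flattened with flat_map and then interleaved. A rough sketch of that pipeline shape, following the pseudocode in the comment removed below (make_dataset_pipeline, file_globs, and cycle_length=2 are illustrative, not an exact reproduction of any user pipeline):

import tensorflow as tf

def make_dataset_pipeline(file_globs):
  # Sketch only: one reader dataset per glob; each glob may match several
  # TFRecord files.
  datasets = []
  for file_glob in file_globs:
    datasets.append(
        tf.data.Dataset.list_files(file_glob, shuffle=False).map(
            tf.data.TFRecordDataset))
  # Build a dataset whose elements are the reader datasets, then flatten it.
  dataset = tf.data.Dataset.from_tensor_slices(datasets)
  dataset = dataset.flat_map(lambda x: x)
  # ... additional preprocessing ...
  dataset = dataset.interleave(lambda x: x, cycle_length=2)
  return dataset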
@@ -45,9 +45,6 @@ constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
 constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
 constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
 constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
-constexpr char kTensorDatasetOpName[] = "TensorDataset";
-constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
-constexpr char kPlaceholderOpName[] = "Placeholder";
 
 constexpr char kNumWorkersAttrName[] = "num_workers";
 constexpr char kNumReplicasAttrName[] = "num_replicas";
@@ -71,13 +68,12 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 25> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
     "CacheDataset",
     "ExperimentalMapAndBatchDataset",
-    "ExperimentalParseExampleDataset",
     "ExperimentalRebatchDataset",
     "FilterDataset",
     "Identity",
@@ -417,33 +413,6 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
   return Status::OK();
 }
 
-const NodeDef* FindFuncAndTensorSliceDataset(
-    const NodeDef* node, int64 num_workers, int64 index,
-    FunctionLibraryDefinition* flib, MutableGraphView* graph,
-    absl::flat_hash_set<string>* nodes_to_delete) {
-  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-    if (input_node->op() == kTensorSliceDatasetOpName ||
-        input_node->op() == kTensorDatasetOpName) {
-      const NodeDef* next_input_node =
-          graph_utils::GetInputNode(*input_node, *graph, 0);
-      if (next_input_node->op() == kPlaceholderOpName) {
-        return node;
-      }
-    }
-  }
-
-  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
-    return nullptr;
-  }
-
-  // Sometimes there are other nodes between the last InterleaveDataset and the
-  // second to last FlatMapDataset, so we need to skip over those.
-  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
-                                       graph, nodes_to_delete);
-}
-
 Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
                            FunctionLibraryDefinition* flib,
                            MutableGraphView* graph,
@@ -472,39 +441,6 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
     return Status::OK();
   }
 
-  // This handles the case for the following subgraph:
-  //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
-  //   (other preprocessing datasets) -> InterleaveDataset
-  // and then inserting the shard node immediately after the FlatMapDataset.
-  //
-  // This is used for some training pipelines where a dataset is created with
-  // the following code:
-  //
-  // def make_dataset_pipeline():
-  //   file_globs = [...]
-  //   datasets = []
-  //   for file_glob in file_globs:
-  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
-  //   dataset = Dataset.from_tensor_slices(datasets)
-  //   dataset = dataset.flat_map(lambda x: x)
-  //   dataset = ...  # additional preprocessing
-  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
-  //   return dataset
-  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
-    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
-        input_node, num_workers, index, flib, graph, nodes_to_delete);
-
-    if (flat_map_node != nullptr) {
-      auto fanouts = graph->GetFanouts(*flat_map_node, false);
-      // FlatMapDataset should only be the input to one other dataset.
-      if (fanouts.size() == 1) {
-        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
-                                        nodes_to_delete, num_workers, index);
-      }
-    }
-  }
-
   // This handles the case where a reader Dataset is contained within a
   // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
   //
@@ -634,6 +570,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
   MutableGraphView graph(output);
   FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
 
+
   NodeDef* sink_node;
   TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
 
@@ -103,43 +103,6 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
-  @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         combinations.combine(batch_size=[1, 3, 10])))
-  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
-    # This tests a scenario where a list_files main return multiple files
-    # due to the glob containing wildcards.
-    def batch(iterator, n):
-      l = len(iterator)
-      for i in range(0, l, n):
-        yield iterator[i:min(i + n, l)]
-
-    datasets = []
-    for files in batch(self.test_filenames, batch_size):
-      datasets.append(
-          dataset_ops.Dataset.list_files(files, shuffle=False).map(
-              core_readers.TFRecordDataset))
-    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
-    dataset = dataset.flat_map(lambda x: x)
-
-    # Simulate additional ops in between flat_map and interleave. This should be
-    # a no-op since if ShardDataset is placed right after flat_map, we will only
-    # have two datasets left at this point.
-    dataset = dataset.prefetch(1)
-    dataset = dataset.prefetch(1)
-
-    dataset = dataset.interleave(
-        lambda x: x, cycle_length=1, num_parallel_calls=1)
-
-    dataset = distribute._AutoShardDataset(dataset, 5, 0)
-    expected = [
-        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
-        for f in (0, 5)
-        for r in range(0, 10)
-    ]
-
-    self.assertDatasetProduces(dataset, expected)
-
   @combinations.generate(test_base.default_test_combinations())
   def testZipReaderPipeline(self):
     dataset1 = dataset_ops.Dataset.list_files(