[rollback] Support new pipelines in autosharding by including them in the FILE autosharding policy
PiperOrigin-RevId: 336902256
Change-Id: I977591868b46405e57612251777fdab4206c4d71
parent d9ad5ce61b
commit 6df9f5a51d

Changed paths:
  tensorflow/core/grappler/optimizers/data
  tensorflow/python/data/experimental/kernel_tests
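For context: the FILE auto-shard policy named in the title is the tf.data option that shards input pipelines at file granularity across workers. A minimal sketch of how a pipeline opts into it (the file glob is a placeholder):

import tensorflow as tf

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.FILE)

dataset = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)
# The AutoShard grappler pass rewrites the serialized graph of this dataset.
dataset = dataset.with_options(options)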
@@ -45,9 +45,6 @@ constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
 constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
 constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
 constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
-constexpr char kTensorDatasetOpName[] = "TensorDataset";
-constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
-constexpr char kPlaceholderOpName[] = "Placeholder";
 
 constexpr char kNumWorkersAttrName[] = "num_workers";
 constexpr char kNumReplicasAttrName[] = "num_replicas";
@@ -71,13 +68,12 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
 
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 25> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
     "CacheDataset",
     "ExperimentalMapAndBatchDataset",
     "ExperimentalParseExampleDataset",
     "ExperimentalRebatchDataset",
     "FilterDataset",
     "Identity",
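kPassThroughOps lists dataset ops the pass may walk through without affecting file-level sharding; shrinking it from 26 to 25 entries drops the entry the rolled-back change had added (the removed entry is not visible in the surviving context above). The net effect of FILE-level sharding, sketched by hand with assumed worker values (the optimizer inserts an equivalent ShardDataset op automatically):

import tensorflow as tf

num_workers, worker_index = 5, 0  # assumed cluster configuration
files = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
files = files.shard(num_workers, worker_index)  # each worker keeps 1/5 of files
dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=1)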
@@ -417,33 +413,6 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
   return Status::OK();
 }
 
-const NodeDef* FindFuncAndTensorSliceDataset(
-    const NodeDef* node, int64 num_workers, int64 index,
-    FunctionLibraryDefinition* flib, MutableGraphView* graph,
-    absl::flat_hash_set<string>* nodes_to_delete) {
-  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-    if (input_node->op() == kTensorSliceDatasetOpName ||
-        input_node->op() == kTensorDatasetOpName) {
-      const NodeDef* next_input_node =
-          graph_utils::GetInputNode(*input_node, *graph, 0);
-      if (next_input_node->op() == kPlaceholderOpName) {
-        return node;
-      }
-    }
-  }
-
-  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
-    return nullptr;
-  }
-
-  // Sometimes there are other nodes between the last InterleaveDataset and the
-  // second to last FlatMapDataset, so we need to skip over those.
-  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
-                                       graph, nodes_to_delete);
-}
-
 Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
                            FunctionLibraryDefinition* flib,
                            MutableGraphView* graph,
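The removed helper recurses from a function dataset toward the graph source: it reports a match when it reaches a FuncDataset fed by a TensorSliceDataset or TensorDataset that is in turn fed by a Placeholder, and otherwise keeps skipping over pass-through ops. A rough Python rendering of that control flow over a toy node type (hypothetical Node class and abbreviated op sets; the real code walks NodeDefs in a MutableGraphView):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:                      # toy stand-in for a GraphDef NodeDef
    op: str
    inputs: List["Node"] = field(default_factory=list)

FUNC_DATASET_OPS = {"FlatMapDataset", "InterleaveDataset"}   # abbreviated
PASS_THROUGH_OPS = {"PrefetchDataset", "Identity"}           # abbreviated

def find_func_and_tensor_slice_dataset(node: Node) -> Optional[Node]:
    # Match: FuncDataset <- (TensorSliceDataset | TensorDataset) <- Placeholder
    if node.op in FUNC_DATASET_OPS:
        inp = node.inputs[0]
        if inp.op in ("TensorSliceDataset", "TensorDataset"):
            if inp.inputs[0].op == "Placeholder":
                return node
    # Any op that is not pass-through ends the search.
    if node.op not in PASS_THROUGH_OPS:
        return None
    return find_func_and_tensor_slice_dataset(node.inputs[0])

# Placeholder -> TensorSliceDataset -> FlatMapDataset -> PrefetchDataset
leaf = Node("Placeholder")
flat_map = Node("FlatMapDataset", [Node("TensorSliceDataset", [leaf])])
root = Node("PrefetchDataset", [flat_map])
assert find_func_and_tensor_slice_dataset(root) is flat_map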
@@ -472,39 +441,6 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
     return Status::OK();
   }
 
-  // This handles the case for the following subgraph:
-  //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
-  //   (other preprocessing datasets) -> InterleaveDataset
-  // and then inserts the shard node immediately after the FlatMapDataset.
-  //
-  // This is used for some training pipelines where a dataset is created with
-  // the following code:
-  //
-  // def make_dataset_pipeline():
-  //   file_globs = [...]
-  //   datasets = []
-  //   for file_glob in file_globs:
-  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
-  //   dataset = Dataset.from_tensor_slices(datasets)
-  //   dataset = dataset.flat_map(lambda x: x)
-  //   dataset = ...  # additional preprocessing
-  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
-  //   return dataset
-  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
-    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
-        input_node, num_workers, index, flib, graph, nodes_to_delete);
-
-    if (flat_map_node != nullptr) {
-      auto fanouts = graph->GetFanouts(*flat_map_node, false);
-      // FlatMapDataset should only be the input to one other dataset.
-      if (fanouts.size() == 1) {
-        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
-                                        nodes_to_delete, num_workers, index);
-      }
-    }
-  }
-
   // This handles the case where a reader Dataset is contained within a
   // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
   //
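The removed comment block describes the exact pipeline shape the rolled-back code matched. A runnable version of that pattern (the file globs are placeholders, and tf.data.TFRecordDataset stands in for the comment's TFRecordReader):

import tensorflow as tf

def make_dataset_pipeline():
    file_globs = ["/data/a-*.tfrecord", "/data/b-*.tfrecord"]  # placeholders
    datasets = []
    for file_glob in file_globs:
        datasets.append(
            tf.data.Dataset.list_files(file_glob, shuffle=False).map(
                tf.data.TFRecordDataset))
    dataset = tf.data.Dataset.from_tensor_slices(datasets)
    dataset = dataset.flat_map(lambda x: x)  # flatten the dataset of datasets
    # ... additional preprocessing would go here ...
    dataset = dataset.interleave(lambda x: x, cycle_length=2)
    return dataset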
@@ -634,6 +570,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
   MutableGraphView graph(output);
   FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
+
   NodeDef* sink_node;
   TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));
 
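This hunk restores OptimizeGraph to its pre-change layout; the pass still locates the fetch (sink) node and rewrites the graph from there. The test file below drives it through the experimental wrapper that appears verbatim in the removed test; a minimal sketch of that invocation (the glob is a placeholder, and _AutoShardDataset is the private test-only hook shown in the diff):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

dataset = tf.data.Dataset.list_files("/data/train-*.tfrecord", shuffle=False)
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=1)
# num_workers=5, index=0, mirroring the removed test's invocation below.
dataset = distribute._AutoShardDataset(dataset, 5, 0)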
@@ -103,43 +103,6 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
 
-  @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         combinations.combine(batch_size=[1, 3, 10])))
-  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
-    # This tests a scenario where a list_files call may return multiple
-    # files due to the glob containing wildcards.
-    def batch(iterator, n):
-      l = len(iterator)
-      for i in range(0, l, n):
-        yield iterator[i:min(i + n, l)]
-
-    datasets = []
-    for files in batch(self.test_filenames, batch_size):
-      datasets.append(
-          dataset_ops.Dataset.list_files(files, shuffle=False).map(
-              core_readers.TFRecordDataset))
-    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
-    dataset = dataset.flat_map(lambda x: x)
-
-    # Simulate additional ops in between flat_map and interleave. This should
-    # be a no-op, since if ShardDataset is placed right after flat_map, we
-    # will only have two datasets left at this point.
-    dataset = dataset.prefetch(1)
-    dataset = dataset.prefetch(1)
-
-    dataset = dataset.interleave(
-        lambda x: x, cycle_length=1, num_parallel_calls=1)
-
-    dataset = distribute._AutoShardDataset(dataset, 5, 0)
-    expected = [
-        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
-        for f in (0, 5)
-        for r in range(0, 10)
-    ]
-
-    self.assertDatasetProduces(dataset, expected)
-
   @combinations.generate(test_base.default_test_combinations())
   def testZipReaderPipeline(self):
     dataset1 = dataset_ops.Dataset.list_files(
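Why the removed test expected records from files 0 and 5: FILE-level sharding with num_shards=5 and index=0 keeps every fifth file, and the test base conventionally writes ten files of ten records each (an assumption inferred from the expected output). The arithmetic, checked directly:

num_files, num_shards, index = 10, 5, 0
kept = [f for f in range(num_files) if f % num_shards == index]
assert kept == [0, 5]  # matches `for f in (0, 5)` in the expected records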