[rollback] Support new pipelines in autosharding by including it in FILE autosharding policy

PiperOrigin-RevId: 336902256
Change-Id: I977591868b46405e57612251777fdab4206c4d71
Frank Chen 2020-10-13 10:16:33 -07:00 committed by TensorFlower Gardener
parent d9ad5ce61b
commit 6df9f5a51d
2 changed files with 2 additions and 102 deletions
tensorflow/core/grappler/optimizers/data
tensorflow/python/data/experimental/kernel_tests

tensorflow/core/grappler/optimizers/data

@@ -45,9 +45,6 @@ constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
 constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
 constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
 constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
-constexpr char kTensorDatasetOpName[] = "TensorDataset";
-constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
-constexpr char kPlaceholderOpName[] = "Placeholder";
 constexpr char kNumWorkersAttrName[] = "num_workers";
 constexpr char kNumReplicasAttrName[] = "num_replicas";
@@ -71,13 +68,12 @@ constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
     "ZipDataset"
 };
-constexpr std::array<const char*, 26> kPassThroughOps = {
+constexpr std::array<const char*, 25> kPassThroughOps = {
     "_Retval",
     "AssertNextDataset",
     "BatchDataset",
     "CacheDataset",
     "ExperimentalMapAndBatchDataset",
     "ExperimentalParseExampleDataset",
     "ExperimentalRebatchDataset",
     "FilterDataset",
     "Identity",
@@ -417,33 +413,6 @@ Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
   return Status::OK();
 }
-const NodeDef* FindFuncAndTensorSliceDataset(
-    const NodeDef* node, int64 num_workers, int64 index,
-    FunctionLibraryDefinition* flib, MutableGraphView* graph,
-    absl::flat_hash_set<string>* nodes_to_delete) {
-  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-    if (input_node->op() == kTensorSliceDatasetOpName ||
-        input_node->op() == kTensorDatasetOpName) {
-      const NodeDef* next_input_node =
-          graph_utils::GetInputNode(*input_node, *graph, 0);
-      if (next_input_node->op() == kPlaceholderOpName) {
-        return node;
-      }
-    }
-  }
-  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
-    return nullptr;
-  }
-  // Sometimes there are other nodes between the last InterleaveDataset and the
-  // second to last FlatMapDataset, so we need to skip over those.
-  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
-  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
-                                       graph, nodes_to_delete);
-}
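In Python terms, the deleted helper performs a recursive match: starting from a FlatMap/Interleave-style node, it either recognizes a Placeholder -> TensorSliceDataset (or TensorDataset) chain directly below it, or skips over one pass-through op and recurses. A minimal sketch over a hypothetical dict-based node representation (not the real NodeDef/MutableGraphView API; the op-name sets are toy stand-ins for the arrays in this file):

# Toy stand-ins for kFuncDatasetOps / kPassThroughOps in auto_shard.cc.
FUNC_DATASET_OPS = {"FlatMapDataset", "InterleaveDataset"}
PASS_THROUGH_OPS = {"BatchDataset", "PrefetchDataset", "Identity"}

def find_func_and_tensor_slice_dataset(node, graph):
  # Match: FlatMap/Interleave fed by a TensorSliceDataset (or TensorDataset)
  # that is itself fed by a Placeholder.
  if node["op"] in FUNC_DATASET_OPS:
    input_node = graph[node["inputs"][0]]
    if input_node["op"] in ("TensorSliceDataset", "TensorDataset"):
      if graph[input_node["inputs"][0]]["op"] == "Placeholder":
        return node
  if node["op"] not in PASS_THROUGH_OPS:
    return None
  # Skip a pass-through op sitting between the interleave and the flat_map,
  # then keep walking toward the source.
  return find_func_and_tensor_slice_dataset(graph[node["inputs"][0]], graph)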
 Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
                            FunctionLibraryDefinition* flib,
                            MutableGraphView* graph,
@@ -472,39 +441,6 @@ Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
     return Status::OK();
   }
-  // This handles the case for the following subgraph:
-  //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
-  //   (other preprocessing datasets) -> InterleaveDataset
-  // and then inserting the shard node immediately after the FlatMapDataset.
-  //
-  // This is used for some training pipelines where a dataset is created with
-  // the following code:
-  //
-  // def make_dataset_pipeline():
-  //   file_globs = [...]
-  //   datasets = []
-  //   for file_glob in file_globs:
-  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
-  //   dataset = Dataset.from_tensor_slices(datasets)
-  //   dataset = dataset.flat_map(lambda x: x)
-  //   dataset = ...  # additional preprocessing
-  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
-  //   return dataset
-  if (IsDatasetNodeOfType(node, kFuncDatasetOps)) {
-    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
-    const NodeDef* flat_map_node = FindFuncAndTensorSliceDataset(
-        input_node, num_workers, index, flib, graph, nodes_to_delete);
-    if (flat_map_node != nullptr) {
-      auto fanouts = graph->GetFanouts(*flat_map_node, false);
-      // FlatMapDataset should only be the input to one other dataset.
-      if (fanouts.size() == 1) {
-        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
-                                        nodes_to_delete, num_workers, index);
-      }
-    }
-  }
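The deleted comment gives the pipeline shape in pseudocode; as a concrete illustration, the same shape can be built with the public tf.data API (a sketch mirroring the deleted test below; `file_globs` and the prefetch stand-in are placeholders):

import tensorflow as tf

def make_dataset_pipeline(file_globs):
  # One reader dataset per glob; each element of `datasets` is itself a
  # dataset of per-file TFRecord readers.
  datasets = []
  for file_glob in file_globs:
    datasets.append(
        tf.data.Dataset.list_files(file_glob, shuffle=False).map(
            tf.data.TFRecordDataset))
  dataset = tf.data.Dataset.from_tensor_slices(datasets)
  # The deleted rewrite would place the ShardDataset right after this
  # flat_map.
  dataset = dataset.flat_map(lambda x: x)
  dataset = dataset.prefetch(1)  # stands in for "additional preprocessing"
  return dataset.interleave(lambda x: x, cycle_length=4)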
 // This handles the case where a reader Dataset is contained within a
 // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
 //
@@ -634,6 +570,7 @@ Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
   MutableGraphView graph(output);
   FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());
   NodeDef* sink_node;
+  TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));

tensorflow/python/data/experimental/kernel_tests

@@ -103,43 +103,6 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
     ]
     self.assertDatasetProducesWithShuffle(dataset, expected, 5, 4, shuffle)
-  @combinations.generate(
-      combinations.times(test_base.default_test_combinations(),
-                         combinations.combine(batch_size=[1, 3, 10])))
-  def testDatasetOfReaderDatasetsPipeline(self, batch_size):
-    # This tests a scenario where a list_files call may return multiple files
-    # due to the glob containing wildcards.
-    def batch(iterator, n):
-      l = len(iterator)
-      for i in range(0, l, n):
-        yield iterator[i:min(i + n, l)]
-
-    datasets = []
-    for files in batch(self.test_filenames, batch_size):
-      datasets.append(
-          dataset_ops.Dataset.list_files(files, shuffle=False).map(
-              core_readers.TFRecordDataset))
-    dataset = dataset_ops.Dataset.from_tensor_slices(datasets)
-    dataset = dataset.flat_map(lambda x: x)
-
-    # Simulate additional ops in between flat_map and interleave. This should
-    # be a no-op since if ShardDataset is placed right after flat_map, we will
-    # only have two datasets left at this point.
-    dataset = dataset.prefetch(1)
-    dataset = dataset.prefetch(1)
-
-    dataset = dataset.interleave(
-        lambda x: x, cycle_length=1, num_parallel_calls=1)
-    dataset = distribute._AutoShardDataset(dataset, 5, 0)
-
-    expected = [
-        b"Record %d of file %d" % (r, f)  # pylint:disable=g-complex-comprehension
-        for f in (0, 5)
-        for r in range(0, 10)
-    ]
-    self.assertDatasetProduces(dataset, expected)
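The deleted test drives the rewrite directly through distribute._AutoShardDataset(dataset, num_workers, index); in user code the FILE policy named in the commit title is normally requested through the options API instead. A sketch (the glob path is a placeholder, and `make_dataset_pipeline` is the illustrative helper from above):

import tensorflow as tf

options = tf.data.Options()
# FILE policy: shard by whole input files across workers, not by elements.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.FILE)

dataset = make_dataset_pipeline(["/path/to/shard-*.tfrecord"])
dataset = dataset.with_options(options)
# Under a tf.distribute strategy, each worker then reads a disjoint subset
# of the matched files.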
   @combinations.generate(test_base.default_test_combinations())
   def testZipReaderPipeline(self):
     dataset1 = dataset_ops.Dataset.list_files(