Add API to control parallel interleave prefetching.
PiperOrigin-RevId: 294766661 Change-Id: I8061629522d19d408cd8b7a1981836a4ee958110
This commit is contained in:
		
							parent
							
								
									a7f1d52b03
								
							
						
					
					
						commit
						0c1ca5c674
					
				@ -0,0 +1,100 @@
 | 
			
		||||
op {
 | 
			
		||||
  graph_op_name: "ParallelInterleaveDatasetV4"
 | 
			
		||||
  visibility: HIDDEN
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "input_dataset"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Dataset that produces a stream of arguments for the function `f`.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "other_arguments"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Additional arguments to pass to `f` beyond those produced by `input_dataset`.
 | 
			
		||||
Evaluated once when the dataset is instantiated.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "cycle_length"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Number of datasets (each created by applying `f` to the elements of
 | 
			
		||||
`input_dataset`) among which the `ParallelInterleaveDatasetV2` will cycle in a
 | 
			
		||||
round-robin fashion.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "block_length"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Number of elements at a time to produce from each interleaved invocation of a
 | 
			
		||||
dataset returned by `f`.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "buffer_output_elements"
 | 
			
		||||
    description: <<END
 | 
			
		||||
The number of elements each iterator being interleaved should buffer (similar
 | 
			
		||||
to the `.prefetch()` transformation for each interleaved iterator).
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "prefetch_input_elements"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Determines the number of iterators to prefetch, allowing buffers to warm up and
 | 
			
		||||
data to be pre-fetched without blocking the main thread.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  in_arg {
 | 
			
		||||
    name: "num_parallel_calls"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Determines the number of threads that should be used for fetching data from
 | 
			
		||||
input datasets in parallel. The Python API `tf.data.experimental.AUTOTUNE`
 | 
			
		||||
constant can be used to indicate that the level of parallelism should be autotuned.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  attr {
 | 
			
		||||
    name: "f"
 | 
			
		||||
    description: <<END
 | 
			
		||||
A function mapping elements of `input_dataset`, concatenated with
 | 
			
		||||
`other_arguments`, to a Dataset variant that contains elements matching
 | 
			
		||||
`output_types` and `output_shapes`.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  attr {
 | 
			
		||||
    name: "deterministic"
 | 
			
		||||
    description: <<END
 | 
			
		||||
A string indicating the op-level determinism to use. Deterministic controls
 | 
			
		||||
whether the interleave is allowed to return elements out of order if the next
 | 
			
		||||
element to be returned isn't available, but a later element is. Options are
 | 
			
		||||
"true", "false", and "default". "default" indicates that determinism should be
 | 
			
		||||
decided by the `experimental_deterministic` parameter of `tf.data.Options`.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  attr {
 | 
			
		||||
    name: "Targuments"
 | 
			
		||||
    description: <<END
 | 
			
		||||
Types of the elements of `other_arguments`.
 | 
			
		||||
END
 | 
			
		||||
  }
 | 
			
		||||
  attr {
 | 
			
		||||
    name: "output_types"
 | 
			
		||||
  }
 | 
			
		||||
  attr {
 | 
			
		||||
    name: "output_shapes"
 | 
			
		||||
  }
 | 
			
		||||
  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
 | 
			
		||||
  description: <<END
 | 
			
		||||
The resulting dataset is similar to the `InterleaveDataset`, except that the
 | 
			
		||||
dataset will fetch records from the interleaved datasets in parallel.
 | 
			
		||||
 | 
			
		||||
The `tf.data` Python API creates instances of this op from
 | 
			
		||||
`Dataset.interleave()` when the `num_parallel_calls` parameter of that method
 | 
			
		||||
is set to any value other than `None`.
 | 
			
		||||
 | 
			
		||||
By default, the output of this dataset will be deterministic, which may result
 | 
			
		||||
in the dataset blocking if the next data item to be returned isn't available.
 | 
			
		||||
In order to avoid head-of-line blocking, one can either set the `deterministic`
 | 
			
		||||
attribute to "false", or leave it as "default" and set the
 | 
			
		||||
`experimental_deterministic` parameter of `tf.data.Options` to `False`.
 | 
			
		||||
This can improve performance at the expense of non-determinism.
 | 
			
		||||
END
 | 
			
		||||
}
 | 
			
		||||
@ -89,13 +89,14 @@ constexpr std::array<const char*, 28> kPassThroughOps = {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
 | 
			
		||||
constexpr std::array<const char*, 6> kFuncDatasetOps = {
 | 
			
		||||
constexpr std::array<const char*, 7> kFuncDatasetOps = {
 | 
			
		||||
    "ExperimentalParallelInterleaveDataset",
 | 
			
		||||
    "FlatMapDataset",
 | 
			
		||||
    "InterleaveDataset",
 | 
			
		||||
    "ParallelInterleaveDataset",
 | 
			
		||||
    "ParallelInterleaveDatasetV2",
 | 
			
		||||
    "ParallelInterleaveDatasetV3"
 | 
			
		||||
    "ParallelInterleaveDatasetV3",
 | 
			
		||||
    "ParallelInterleaveDatasetV4"
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
 | 
			
		||||
 | 
			
		||||
@ -33,12 +33,10 @@ namespace {
 | 
			
		||||
constexpr char kLegacyAutotune[] = "legacy_autotune";
 | 
			
		||||
constexpr char kPrefetchDataset[] = "PrefetchDataset";
 | 
			
		||||
 | 
			
		||||
constexpr std::array<const char*, 5> kAsyncDatasetOps = {
 | 
			
		||||
    "ExperimentalMapAndBatchDataset",
 | 
			
		||||
    "ParallelMapDataset",
 | 
			
		||||
    "ParallelInterleaveDatasetV2",
 | 
			
		||||
    "ParallelInterleaveDatasetV3",
 | 
			
		||||
    "MapAndBatchDataset",
 | 
			
		||||
constexpr std::array<const char*, 6> kAsyncDatasetOps = {
 | 
			
		||||
    "ExperimentalMapAndBatchDataset", "ParallelMapDataset",
 | 
			
		||||
    "ParallelInterleaveDatasetV2",    "ParallelInterleaveDatasetV3",
 | 
			
		||||
    "ParallelInterleaveDatasetV4",    "MapAndBatchDataset",
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
@ -32,8 +32,9 @@ constexpr std::array<const char*, 3> kSloppyAttrOps = {
 | 
			
		||||
    "ParseExampleDataset",
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
constexpr std::array<const char*, 1> kDeterministicAttrOps = {
 | 
			
		||||
constexpr std::array<const char*, 2> kDeterministicAttrOps = {
 | 
			
		||||
    "ParallelInterleaveDatasetV3",
 | 
			
		||||
    "ParallelInterleaveDatasetV4",
 | 
			
		||||
};
 | 
			
		||||
}  // anonymous namespace
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -59,6 +59,10 @@ namespace data {
 | 
			
		||||
    ParallelInterleaveDatasetOp::kCycleLength;
 | 
			
		||||
/* static */ constexpr const char* const
 | 
			
		||||
    ParallelInterleaveDatasetOp::kBlockLength;
 | 
			
		||||
/* static */ constexpr const char* const
 | 
			
		||||
    ParallelInterleaveDatasetOp::kBufferOutputElements;
 | 
			
		||||
/* static */ constexpr const char* const
 | 
			
		||||
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
 | 
			
		||||
/* static */ constexpr const char* const
 | 
			
		||||
    ParallelInterleaveDatasetOp::kNumParallelCalls;
 | 
			
		||||
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
 | 
			
		||||
@ -91,34 +95,64 @@ constexpr char kSizeSuffix[] = ".size";
 | 
			
		||||
constexpr char kInputsSuffix[] = ".inputs";
 | 
			
		||||
constexpr char kIsReadySuffix[] = ".is_ready";
 | 
			
		||||
 | 
			
		||||
// `kCyclePrefetchFactor * cycle_length` is the number of future cycle elements
 | 
			
		||||
// that will be prefetched ahead of time. The purpose of prefetching future
 | 
			
		||||
// cycle elements is to overlap expensive initialization (e.g. opening of a
 | 
			
		||||
// remote file) with other computation.
 | 
			
		||||
constexpr double kCyclePrefetchFactor = 2.0L;
 | 
			
		||||
constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
 | 
			
		||||
constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
 | 
			
		||||
constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";
 | 
			
		||||
 | 
			
		||||
// `kPerIteratorPrefetchFactor * block_length + 1` is the number of per-iterator
 | 
			
		||||
// results that will be prefetched ahead of time. The `+ 1` is to match the
 | 
			
		||||
// behavior of the original autotune implementation.
 | 
			
		||||
constexpr double kPerIteratorPrefetchFactor = 2.0L;
 | 
			
		||||
// `kCyclePrefetchFactor * cycle_length` is the default number of future cycle
 | 
			
		||||
// elements that will be prefetched ahead of time. The purpose of prefetching
 | 
			
		||||
// future cycle elements is to overlap expensive initialization (e.g. opening of
 | 
			
		||||
// a remote file) with other computation.
 | 
			
		||||
constexpr double kDefaultCyclePrefetchFactor = 2.0L;
 | 
			
		||||
 | 
			
		||||
// `kPerIteratorPrefetchFactor * block_length + 1` is the defualt number of
 | 
			
		||||
// per-iterator results that will be prefetched ahead of time. The `+ 1` is to
 | 
			
		||||
// match the behavior of the original implementation.
 | 
			
		||||
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
 | 
			
		||||
 | 
			
		||||
// Period between reporting dataset statistics.
 | 
			
		||||
constexpr int kStatsReportingPeriodMillis = 1000;
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
 | 
			
		||||
                                  int64 block_length) {
 | 
			
		||||
  if (configured_buffer_output_elements != model::kAutotune) {
 | 
			
		||||
    return configured_buffer_output_elements;
 | 
			
		||||
  }
 | 
			
		||||
  return kDefaultPerIteratorPrefetchFactor * block_length + 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
 | 
			
		||||
                                   int64 cycle_length) {
 | 
			
		||||
  if (configured_prefetch_input_elements != model::kAutotune) {
 | 
			
		||||
    return configured_prefetch_input_elements;
 | 
			
		||||
  }
 | 
			
		||||
  return kDefaultCyclePrefetchFactor * cycle_length;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int64 OpVersionFromOpName(absl::string_view op_name) {
 | 
			
		||||
  if (op_name == kParallelInterleaveDatasetV2) {
 | 
			
		||||
    return 2;
 | 
			
		||||
  } else if (op_name == kParallelInterleaveDatasetV3) {
 | 
			
		||||
    return 3;
 | 
			
		||||
  } else {
 | 
			
		||||
    DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
 | 
			
		||||
    return 4;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
// The motivation for creating an alternative implementation of parallel
 | 
			
		||||
// interleave is to decouple the degree of parallelism from the cycle length.
 | 
			
		||||
// This makes it possible to change the degree of parallelism (e.g. through
 | 
			
		||||
// auto-tuning) without changing the cycle length (which would change the order
 | 
			
		||||
// in which elements are produced).
 | 
			
		||||
//
 | 
			
		||||
// Furthermore, this class favors modularity over extended functionality. In
 | 
			
		||||
// particular, it refrains from implementing configurable buffering of output
 | 
			
		||||
// elements and prefetching of input iterators.
 | 
			
		||||
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
 public:
 | 
			
		||||
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
 | 
			
		||||
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
 | 
			
		||||
          int64 block_length, int64 num_parallel_calls,
 | 
			
		||||
          int64 block_length, int64 buffer_output_elements,
 | 
			
		||||
          int64 prefetch_input_elements, int64 num_parallel_calls,
 | 
			
		||||
          DeterminismPolicy deterministic, const DataTypeVector& output_types,
 | 
			
		||||
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
 | 
			
		||||
      : DatasetBase(DatasetContext(ctx)),
 | 
			
		||||
@ -126,6 +160,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
        captured_func_(std::move(captured_func)),
 | 
			
		||||
        cycle_length_(cycle_length),
 | 
			
		||||
        block_length_(block_length),
 | 
			
		||||
        buffer_output_elements_(
 | 
			
		||||
            ComputeBufferOutputElements(buffer_output_elements, block_length)),
 | 
			
		||||
        prefetch_input_elements_(ComputePrefetchInputElements(
 | 
			
		||||
            prefetch_input_elements, cycle_length)),
 | 
			
		||||
        num_parallel_calls_(num_parallel_calls),
 | 
			
		||||
        deterministic_(deterministic),
 | 
			
		||||
        output_types_(output_types),
 | 
			
		||||
@ -179,19 +217,44 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
  Status AsGraphDefInternal(SerializationContext* ctx,
 | 
			
		||||
                            DatasetGraphDefBuilder* b,
 | 
			
		||||
                            Node** output) const override {
 | 
			
		||||
    std::vector<std::pair<size_t, Node*>> inputs;
 | 
			
		||||
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
 | 
			
		||||
    int input_index = 0;
 | 
			
		||||
 | 
			
		||||
    Node* input_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
 | 
			
		||||
    Node* cycle_length_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
 | 
			
		||||
    Node* block_length_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
 | 
			
		||||
    Node* num_parallel_calls_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
 | 
			
		||||
    inputs.emplace_back(input_index++, input_node);
 | 
			
		||||
 | 
			
		||||
    std::vector<Node*> other_arguments;
 | 
			
		||||
    DataTypeVector other_arguments_types;
 | 
			
		||||
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
 | 
			
		||||
                                                  &other_arguments_types));
 | 
			
		||||
    list_inputs.emplace_back(input_index++, other_arguments);
 | 
			
		||||
 | 
			
		||||
    Node* cycle_length_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
 | 
			
		||||
    inputs.emplace_back(input_index++, cycle_length_node);
 | 
			
		||||
 | 
			
		||||
    Node* block_length_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
 | 
			
		||||
    inputs.emplace_back(input_index++, block_length_node);
 | 
			
		||||
 | 
			
		||||
    if (op_version_ >= 4) {
 | 
			
		||||
      Node* buffer_output_elements_node;
 | 
			
		||||
      TF_RETURN_IF_ERROR(
 | 
			
		||||
          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
 | 
			
		||||
      inputs.emplace_back(input_index++, buffer_output_elements_node);
 | 
			
		||||
 | 
			
		||||
      Node* prefetch_input_elements_node;
 | 
			
		||||
      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
 | 
			
		||||
                                      &prefetch_input_elements_node));
 | 
			
		||||
      inputs.emplace_back(input_index++, prefetch_input_elements_node);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Node* num_parallel_calls_node;
 | 
			
		||||
    TF_RETURN_IF_ERROR(
 | 
			
		||||
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
 | 
			
		||||
    inputs.emplace_back(input_index++, num_parallel_calls_node);
 | 
			
		||||
 | 
			
		||||
    std::vector<std::pair<StringPiece, AttrValue>> attrs;
 | 
			
		||||
    AttrValue f;
 | 
			
		||||
@ -207,18 +270,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
 | 
			
		||||
      attrs.emplace_back(kSloppy, sloppy_attr);
 | 
			
		||||
    }
 | 
			
		||||
    if (op_version_ == 3) {
 | 
			
		||||
    if (op_version_ >= 3) {
 | 
			
		||||
      AttrValue deterministic_attr;
 | 
			
		||||
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
 | 
			
		||||
      attrs.emplace_back(kDeterministic, deterministic_attr);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddDataset(this,
 | 
			
		||||
                                     {{0, input_node},
 | 
			
		||||
                                      {2, cycle_length_node},
 | 
			
		||||
                                      {3, block_length_node},
 | 
			
		||||
                                      {4, num_parallel_calls_node}},
 | 
			
		||||
                                     {{1, other_arguments}}, attrs, output));
 | 
			
		||||
    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -227,12 +285,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
   public:
 | 
			
		||||
    ParallelInterleaveIterator(const Params& params, bool deterministic)
 | 
			
		||||
        : DatasetIterator<Dataset>(params),
 | 
			
		||||
          per_iterator_prefetch_(
 | 
			
		||||
              static_cast<int>(params.dataset->block_length_ *
 | 
			
		||||
                               kPerIteratorPrefetchFactor) +
 | 
			
		||||
              1),
 | 
			
		||||
          future_elements_prefetch_(static_cast<int>(
 | 
			
		||||
              params.dataset->cycle_length_ * kCyclePrefetchFactor)),
 | 
			
		||||
          mu_(std::make_shared<mutex>()),
 | 
			
		||||
          num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
 | 
			
		||||
          num_parallel_calls_(std::make_shared<model::SharedState>(
 | 
			
		||||
@ -257,7 +309,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
      // collection, `cycle_length_` threads for the current workers, and
 | 
			
		||||
      // `future_elements_prefetch_` for the future workers.
 | 
			
		||||
      int max_current_workers = dataset()->cycle_length_;
 | 
			
		||||
      int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
 | 
			
		||||
      int future_workers =
 | 
			
		||||
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
 | 
			
		||||
      int num_threads = 1 + max_current_workers + future_workers;
 | 
			
		||||
      if (ctx->stats_aggregator()) {
 | 
			
		||||
        num_threads++;
 | 
			
		||||
@ -656,7 +709,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
      // `cycle_length_` future workers to guarantee that whenever
 | 
			
		||||
      // `future_element_.size() < future_elements_prefetch_`, there will be a
 | 
			
		||||
      // future worker available to create a new future element.
 | 
			
		||||
      int future_workers = future_elements_prefetch_ + dataset()->cycle_length_;
 | 
			
		||||
      int future_workers =
 | 
			
		||||
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
 | 
			
		||||
      {
 | 
			
		||||
        mutex_lock l(*mu_);
 | 
			
		||||
        initial_current_workers = num_parallel_calls_->value;
 | 
			
		||||
@ -800,9 +854,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
              current_workers_cond_var_.notify_one();
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
          while (!cancelled_ &&
 | 
			
		||||
                 (future_elements_.size() >= future_elements_prefetch_ ||
 | 
			
		||||
                  wait_for_checkpoint_)) {
 | 
			
		||||
          while (!cancelled_ && (future_elements_.size() >=
 | 
			
		||||
                                     dataset()->prefetch_input_elements_ ||
 | 
			
		||||
                                 wait_for_checkpoint_)) {
 | 
			
		||||
            WaitWorkerThread(&future_workers_cond_var_, &l);
 | 
			
		||||
          }
 | 
			
		||||
          if (cancelled_) {
 | 
			
		||||
@ -860,7 +914,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
        mutex_lock l(*mu_);
 | 
			
		||||
        element->results.push_back(std::move(result));
 | 
			
		||||
        NotifyElementUpdate(element);
 | 
			
		||||
        if (element->results.size() == per_iterator_prefetch_) {
 | 
			
		||||
        if (element->results.size() == dataset()->buffer_output_elements_) {
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
@ -954,7 +1008,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
        return true;
 | 
			
		||||
      }
 | 
			
		||||
      return element->iterator &&
 | 
			
		||||
             element->results.size() < per_iterator_prefetch_;
 | 
			
		||||
             element->results.size() < dataset()->buffer_output_elements_;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    inline void IncrementCurrentWorkers() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
 | 
			
		||||
@ -1311,9 +1365,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
    // its elements are null.
 | 
			
		||||
    int64 last_valid_current_element_ GUARDED_BY(mu_) = -1;
 | 
			
		||||
 | 
			
		||||
    const int per_iterator_prefetch_;
 | 
			
		||||
    const int future_elements_prefetch_;
 | 
			
		||||
 | 
			
		||||
    // Identifies whether the current_elements_ vector has been initialized.
 | 
			
		||||
    bool initial_elements_created_ GUARDED_BY(mu_) = false;
 | 
			
		||||
 | 
			
		||||
@ -1421,6 +1472,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
  const std::unique_ptr<CapturedFunction> captured_func_;
 | 
			
		||||
  const int64 cycle_length_;
 | 
			
		||||
  const int64 block_length_;
 | 
			
		||||
  const int64 buffer_output_elements_;
 | 
			
		||||
  const int64 prefetch_input_elements_;
 | 
			
		||||
  const int64 num_parallel_calls_;
 | 
			
		||||
  const DeterminismPolicy deterministic_;
 | 
			
		||||
  const DataTypeVector output_types_;
 | 
			
		||||
@ -1431,7 +1484,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
 | 
			
		||||
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
 | 
			
		||||
    OpKernelConstruction* ctx)
 | 
			
		||||
    : UnaryDatasetOpKernel(ctx), op_version_(ctx->HasAttr(kSloppy) ? 2 : 3) {
 | 
			
		||||
    : UnaryDatasetOpKernel(ctx),
 | 
			
		||||
      op_version_(OpVersionFromOpName(ctx->def().op())) {
 | 
			
		||||
  FunctionMetadata::Params params;
 | 
			
		||||
  params.is_multi_device_function = true;
 | 
			
		||||
  OP_REQUIRES_OK(ctx,
 | 
			
		||||
@ -1448,7 +1502,7 @@ ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
 | 
			
		||||
      deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  if (op_version_ == 3) {
 | 
			
		||||
  if (op_version_ >= 3) {
 | 
			
		||||
    std::string deterministic;
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
@ -1472,6 +1526,26 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
 | 
			
		||||
  OP_REQUIRES(ctx, block_length > 0,
 | 
			
		||||
              errors::InvalidArgument("`block_length` must be > 0"));
 | 
			
		||||
 | 
			
		||||
  int64 buffer_output_elements = model::kAutotune;
 | 
			
		||||
  int64 prefetch_input_elements = model::kAutotune;
 | 
			
		||||
  if (op_version_ >= 4) {
 | 
			
		||||
    OP_REQUIRES(ctx,
 | 
			
		||||
                buffer_output_elements == model::kAutotune ||
 | 
			
		||||
                    buffer_output_elements >= 0,
 | 
			
		||||
                errors::InvalidArgument("`buffer_output_elements` must be ",
 | 
			
		||||
                                        model::kAutotune, " or >= 0 but is ",
 | 
			
		||||
                                        buffer_output_elements));
 | 
			
		||||
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
 | 
			
		||||
                                            &prefetch_input_elements));
 | 
			
		||||
    OP_REQUIRES(ctx,
 | 
			
		||||
                prefetch_input_elements == model::kAutotune ||
 | 
			
		||||
                    prefetch_input_elements >= 0,
 | 
			
		||||
                errors::InvalidArgument("`prefetch_input_elements` must be ",
 | 
			
		||||
                                        model::kAutotune, " or >= 0 but is ",
 | 
			
		||||
                                        prefetch_input_elements));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int64 num_parallel_calls = 0;
 | 
			
		||||
  OP_REQUIRES_OK(
 | 
			
		||||
      ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
 | 
			
		||||
@ -1492,18 +1566,22 @@ void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
 | 
			
		||||
    metrics::RecordTFDataAutotune(kDatasetType);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
 | 
			
		||||
                        block_length, num_parallel_calls, deterministic_,
 | 
			
		||||
                        output_types_, output_shapes_, op_version_);
 | 
			
		||||
  *output = new Dataset(
 | 
			
		||||
      ctx, input, std::move(captured_func), cycle_length, block_length,
 | 
			
		||||
      buffer_output_elements, prefetch_input_elements, num_parallel_calls,
 | 
			
		||||
      deterministic_, output_types_, output_shapes_, op_version_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
 | 
			
		||||
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
 | 
			
		||||
                        ParallelInterleaveDatasetOp);
 | 
			
		||||
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV3").Device(DEVICE_CPU),
 | 
			
		||||
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
 | 
			
		||||
                        ParallelInterleaveDatasetOp);
 | 
			
		||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV2");
 | 
			
		||||
REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDatasetV3");
 | 
			
		||||
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
 | 
			
		||||
                        ParallelInterleaveDatasetOp);
 | 
			
		||||
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
 | 
			
		||||
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
 | 
			
		||||
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace data
 | 
			
		||||
}  // namespace tensorflow
 | 
			
		||||
 | 
			
		||||
@ -29,6 +29,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 | 
			
		||||
  static constexpr const char* const kOtherArguments = "other_arguments";
 | 
			
		||||
  static constexpr const char* const kCycleLength = "cycle_length";
 | 
			
		||||
  static constexpr const char* const kBlockLength = "block_length";
 | 
			
		||||
  static constexpr const char* const kBufferOutputElements =
 | 
			
		||||
      "buffer_output_elements";
 | 
			
		||||
  static constexpr const char* const kPrefetchInputElements =
 | 
			
		||||
      "prefetch_input_elements";
 | 
			
		||||
  static constexpr const char* const kNumParallelCalls = "num_parallel_calls";
 | 
			
		||||
  static constexpr const char* const kFunc = "f";
 | 
			
		||||
  static constexpr const char* const kTarguments = "Targuments";
 | 
			
		||||
 | 
			
		||||
@ -18,14 +18,15 @@ namespace data {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
constexpr char kNodeName[] = "parallel_interleave_dataset";
 | 
			
		||||
constexpr int kOpVersion = 3;
 | 
			
		||||
constexpr int kOpVersion = 4;
 | 
			
		||||
 | 
			
		||||
class ParallelInterleaveDatasetParams : public DatasetParams {
 | 
			
		||||
 public:
 | 
			
		||||
  template <typename T>
 | 
			
		||||
  ParallelInterleaveDatasetParams(
 | 
			
		||||
      T input_dataset_params, std::vector<Tensor> other_arguments,
 | 
			
		||||
      int64 cycle_length, int64 block_length, int64 num_parallel_calls,
 | 
			
		||||
      int64 cycle_length, int64 block_length, int64 buffer_output_elements,
 | 
			
		||||
      int64 prefetch_input_elements, int64 num_parallel_calls,
 | 
			
		||||
      FunctionDefHelper::AttrValueWrapper func,
 | 
			
		||||
      std::vector<FunctionDef> func_lib, DataTypeVector type_arguments,
 | 
			
		||||
      const DataTypeVector& output_dtypes,
 | 
			
		||||
@ -36,6 +37,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
 | 
			
		||||
        other_arguments_(std::move(other_arguments)),
 | 
			
		||||
        cycle_length_(cycle_length),
 | 
			
		||||
        block_length_(block_length),
 | 
			
		||||
        buffer_output_elements_(buffer_output_elements),
 | 
			
		||||
        prefetch_input_elements_(prefetch_input_elements),
 | 
			
		||||
        num_parallel_calls_(num_parallel_calls),
 | 
			
		||||
        func_(std::move(func)),
 | 
			
		||||
        func_lib_(std::move(func_lib)),
 | 
			
		||||
@ -56,6 +59,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
 | 
			
		||||
        CreateTensor<int64>(TensorShape({}), {cycle_length_}));
 | 
			
		||||
    input_tensors.emplace_back(
 | 
			
		||||
        CreateTensor<int64>(TensorShape({}), {block_length_}));
 | 
			
		||||
    input_tensors.emplace_back(
 | 
			
		||||
        CreateTensor<int64>(TensorShape({}), {buffer_output_elements_}));
 | 
			
		||||
    input_tensors.emplace_back(
 | 
			
		||||
        CreateTensor<int64>(TensorShape({}), {prefetch_input_elements_}));
 | 
			
		||||
    input_tensors.emplace_back(
 | 
			
		||||
        CreateTensor<int64>(TensorShape({}), {num_parallel_calls_}));
 | 
			
		||||
    return input_tensors;
 | 
			
		||||
@ -69,6 +76,10 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
 | 
			
		||||
    }
 | 
			
		||||
    input_names->emplace_back(ParallelInterleaveDatasetOp::kCycleLength);
 | 
			
		||||
    input_names->emplace_back(ParallelInterleaveDatasetOp::kBlockLength);
 | 
			
		||||
    input_names->emplace_back(
 | 
			
		||||
        ParallelInterleaveDatasetOp::kBufferOutputElements);
 | 
			
		||||
    input_names->emplace_back(
 | 
			
		||||
        ParallelInterleaveDatasetOp::kPrefetchInputElements);
 | 
			
		||||
    input_names->emplace_back(ParallelInterleaveDatasetOp::kNumParallelCalls);
 | 
			
		||||
    return Status::OK();
 | 
			
		||||
  }
 | 
			
		||||
@ -93,6 +104,8 @@ class ParallelInterleaveDatasetParams : public DatasetParams {
 | 
			
		||||
  std::vector<Tensor> other_arguments_;
 | 
			
		||||
  int64 cycle_length_;
 | 
			
		||||
  int64 block_length_;
 | 
			
		||||
  int64 buffer_output_elements_;
 | 
			
		||||
  int64 prefetch_input_elements_;
 | 
			
		||||
  int64 num_parallel_calls_;
 | 
			
		||||
  FunctionDefHelper::AttrValueWrapper func_;
 | 
			
		||||
  std::vector<FunctionDef> func_lib_;
 | 
			
		||||
@ -111,8 +124,6 @@ FunctionDefHelper::AttrValueWrapper MakeTensorSliceDatasetFunc(
 | 
			
		||||
                 {"output_shapes", output_shapes}});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 1: cycle_length = 1, block_length = 1, num_parallel_calls = 1,
 | 
			
		||||
// sloppy = false
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
 | 
			
		||||
@ -123,6 +134,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/1,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/1,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -136,8 +149,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams1() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 2: cycle_length = 2, block_length = 1, num_parallel_calls = 2,
 | 
			
		||||
// sloppy = false
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
 | 
			
		||||
@ -148,6 +159,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/2,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/0,
 | 
			
		||||
      /*prefetch_input_elements=*/0,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -161,8 +174,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams2() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 3: cycle_length = 3, block_length = 1, num_parallel_calls = 2,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
 | 
			
		||||
@ -173,6 +184,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/3,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/0,
 | 
			
		||||
      /*prefetch_input_elements=*/1,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -186,9 +199,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams3() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 4: cycle_length = 5, block_length = 1, num_parallel_calls = 4,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
 | 
			
		||||
@ -199,6 +209,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/5,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/1,
 | 
			
		||||
      /*prefetch_input_elements=*/0,
 | 
			
		||||
      /*num_parallel_calls=*/4,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -212,8 +224,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams4() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 5: cycle_length = 2, block_length = 2, num_parallel_calls = 1,
 | 
			
		||||
// sloppy = false
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<tstring>(
 | 
			
		||||
@ -224,6 +234,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/2,
 | 
			
		||||
      /*block_length=*/2,
 | 
			
		||||
      /*buffer_output_elements=*/2,
 | 
			
		||||
      /*prefetch_input_elements=*/2,
 | 
			
		||||
      /*num_parallel_calls=*/1,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -237,8 +249,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams5() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 6: cycle_length = 2, block_length = 3, num_parallel_calls = 2,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<tstring>(
 | 
			
		||||
@ -249,6 +259,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/2,
 | 
			
		||||
      /*block_length=*/3,
 | 
			
		||||
      /*buffer_output_elements=*/100,
 | 
			
		||||
      /*prefetch_input_elements=*/100,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -262,8 +274,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams6() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 7: cycle_length = 3, block_length = 2, num_parallel_calls = 2,
 | 
			
		||||
// sloppy = false
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<tstring>(
 | 
			
		||||
@ -274,6 +284,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/3,
 | 
			
		||||
      /*block_length=*/2,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -287,8 +299,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams7() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 8: cycle_length = 3, block_length = 3, num_parallel_calls = 3,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<tstring>(
 | 
			
		||||
@ -299,6 +309,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/3,
 | 
			
		||||
      /*block_length=*/3,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/3,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -312,8 +324,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams8() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 9: cycle_length = 4, block_length = 4, num_parallel_calls = 4,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<tstring>(
 | 
			
		||||
@ -324,6 +334,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/4,
 | 
			
		||||
      /*block_length=*/4,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/4,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -337,8 +349,6 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams9() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 10: cycle_length = 3, block_length = 3,
 | 
			
		||||
// num_parallel_calls = kAutotune, sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<tstring>(
 | 
			
		||||
@ -349,6 +359,8 @@ ParallelInterleaveDatasetParams ParallelInterleaveDatasetParams10() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/4,
 | 
			
		||||
      /*block_length=*/4,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/model::kAutotune,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -372,6 +384,8 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/11,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -385,8 +399,6 @@ ParallelInterleaveDatasetParams LongCycleDeteriministicParams() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 11: cycle_length = 0, block_length = 1, num_parallel_calls = 2,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams
 | 
			
		||||
ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
@ -398,6 +410,8 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/0,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -411,8 +425,6 @@ ParallelInterleaveDatasetParamsWithInvalidCycleLength() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 12: cycle_length = 1, block_length = -1, num_parallel_calls = 2,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams
 | 
			
		||||
ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
@ -424,6 +436,8 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/1,
 | 
			
		||||
      /*block_length=*/-1,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/2,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -437,8 +451,6 @@ ParallelInterleaveDatasetParamsWithInvalidBlockLength() {
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// test case 13: cycle_length = 1, block_length = 1, num_parallel_calls = -5,
 | 
			
		||||
// sloppy = true
 | 
			
		||||
ParallelInterleaveDatasetParams
 | 
			
		||||
ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
@ -450,6 +462,60 @@ ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls() {
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/1,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/model::kAutotune,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/-5,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
          DataTypeVector({DT_INT64}),
 | 
			
		||||
          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
 | 
			
		||||
      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
 | 
			
		||||
      /*type_arguments=*/{},
 | 
			
		||||
      /*output_dtypes=*/{DT_INT64},
 | 
			
		||||
      /*output_shapes=*/{PartialTensorShape({1})},
 | 
			
		||||
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ParallelInterleaveDatasetParams
 | 
			
		||||
ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
 | 
			
		||||
                                          {0, 1, 2, 3, 4, 5, 6, 7, 8})},
 | 
			
		||||
      /*node_name=*/"tensor_slice");
 | 
			
		||||
  return ParallelInterleaveDatasetParams(
 | 
			
		||||
      tensor_slice_dataset_params,
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/1,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/-2,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/-5,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
          DataTypeVector({DT_INT64}),
 | 
			
		||||
          std::vector<PartialTensorShape>({PartialTensorShape({1})})),
 | 
			
		||||
      /*func_lib=*/{test::function::MakeTensorSliceDataset()},
 | 
			
		||||
      /*type_arguments=*/{},
 | 
			
		||||
      /*output_dtypes=*/{DT_INT64},
 | 
			
		||||
      /*output_shapes=*/{PartialTensorShape({1})},
 | 
			
		||||
      /*deterministic=*/DeterminismPolicy::kNondeterministic,
 | 
			
		||||
      /*node_name=*/kNodeName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ParallelInterleaveDatasetParams
 | 
			
		||||
ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements() {
 | 
			
		||||
  auto tensor_slice_dataset_params = TensorSliceDatasetParams(
 | 
			
		||||
      /*components=*/{CreateTensor<int64>(TensorShape{3, 3, 1},
 | 
			
		||||
                                          {0, 1, 2, 3, 4, 5, 6, 7, 8})},
 | 
			
		||||
      /*node_name=*/"tensor_slice");
 | 
			
		||||
  return ParallelInterleaveDatasetParams(
 | 
			
		||||
      tensor_slice_dataset_params,
 | 
			
		||||
      /*other_arguments=*/{},
 | 
			
		||||
      /*cycle_length=*/1,
 | 
			
		||||
      /*block_length=*/1,
 | 
			
		||||
      /*buffer_output_elements=*/-2,
 | 
			
		||||
      /*prefetch_input_elements=*/model::kAutotune,
 | 
			
		||||
      /*num_parallel_calls=*/-5,
 | 
			
		||||
      /*func=*/
 | 
			
		||||
      MakeTensorSliceDatasetFunc(
 | 
			
		||||
@ -698,7 +764,10 @@ TEST_F(ParallelInterleaveDatasetOpTest, InvalidArguments) {
 | 
			
		||||
  std::vector<ParallelInterleaveDatasetParams> invalid_params = {
 | 
			
		||||
      ParallelInterleaveDatasetParamsWithInvalidCycleLength(),
 | 
			
		||||
      ParallelInterleaveDatasetParamsWithInvalidBlockLength(),
 | 
			
		||||
      ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls()};
 | 
			
		||||
      ParallelInterleaveDatasetParamsWithInvalidNumParallelCalls(),
 | 
			
		||||
      ParallelInterleaveDatasetParamsWithInvalidBufferOutputElements(),
 | 
			
		||||
      ParallelInterleaveDatasetParamsWithInvalidPrefetchInputElements(),
 | 
			
		||||
  };
 | 
			
		||||
  for (auto& dataset_params : invalid_params) {
 | 
			
		||||
    EXPECT_EQ(Initialize(dataset_params).code(),
 | 
			
		||||
              tensorflow::error::INVALID_ARGUMENT);
 | 
			
		||||
 | 
			
		||||
@ -227,6 +227,24 @@ REGISTER_OP("ParallelInterleaveDatasetV3")
 | 
			
		||||
    .Attr("output_shapes: list(shape) >= 1")
 | 
			
		||||
    .SetShapeFn(shape_inference::ScalarShape);
 | 
			
		||||
 | 
			
		||||
// Like V3, but adds buffer_output_elements and prefetch_input_elements.
 | 
			
		||||
REGISTER_OP("ParallelInterleaveDatasetV4")
 | 
			
		||||
    .Input("input_dataset: variant")
 | 
			
		||||
    .Input("other_arguments: Targuments")
 | 
			
		||||
    .Input("cycle_length: int64")
 | 
			
		||||
    .Input("block_length: int64")
 | 
			
		||||
    .Input("buffer_output_elements: int64")
 | 
			
		||||
    .Input("prefetch_input_elements: int64")
 | 
			
		||||
    .Input("num_parallel_calls: int64")
 | 
			
		||||
    .Output("handle: variant")
 | 
			
		||||
    .Attr("f: func")
 | 
			
		||||
    // "true", "false", or "default".
 | 
			
		||||
    .Attr("deterministic: string = 'default'")
 | 
			
		||||
    .Attr("Targuments: list(type) >= 0")
 | 
			
		||||
    .Attr("output_types: list(type) >= 1")
 | 
			
		||||
    .Attr("output_shapes: list(shape) >= 1")
 | 
			
		||||
    .SetShapeFn(shape_inference::ScalarShape);
 | 
			
		||||
 | 
			
		||||
REGISTER_OP("FilterDataset")
 | 
			
		||||
    .Input("input_dataset: variant")
 | 
			
		||||
    .Input("other_arguments: Targuments")
 | 
			
		||||
 | 
			
		||||
@ -64,6 +64,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    parallel_interleave = "ParallelInterleaveV2"
 | 
			
		||||
    if compat.forward_compatible(2020, 2, 20):
 | 
			
		||||
      parallel_interleave = "ParallelInterleaveV3"
 | 
			
		||||
    if compat.forward_compatible(2020, 3, 6):
 | 
			
		||||
      parallel_interleave = "ParallelInterleaveV4"
 | 
			
		||||
    dataset = dataset.apply(
 | 
			
		||||
        testing.assert_next([parallel_interleave, "Prefetch", "FiniteTake"]))
 | 
			
		||||
    dataset = dataset.interleave(
 | 
			
		||||
@ -79,6 +81,8 @@ class InjectPrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
    parallel_interleave = "ParallelInterleaveV2"
 | 
			
		||||
    if compat.forward_compatible(2020, 2, 20):
 | 
			
		||||
      parallel_interleave = "ParallelInterleaveV3"
 | 
			
		||||
    if compat.forward_compatible(2020, 3, 6):
 | 
			
		||||
      parallel_interleave = "ParallelInterleaveV4"
 | 
			
		||||
    dataset = dataset.apply(
 | 
			
		||||
        testing.assert_next([
 | 
			
		||||
            "ParallelMap", "Prefetch", parallel_interleave, "Prefetch",
 | 
			
		||||
 | 
			
		||||
@ -1714,9 +1714,13 @@ name=None))
 | 
			
		||||
    if num_parallel_calls is None:
 | 
			
		||||
      return InterleaveDataset(self, map_func, cycle_length, block_length)
 | 
			
		||||
    else:
 | 
			
		||||
      return ParallelInterleaveDataset(self, map_func, cycle_length,
 | 
			
		||||
                                       block_length, num_parallel_calls,
 | 
			
		||||
                                       deterministic)
 | 
			
		||||
      return ParallelInterleaveDataset(
 | 
			
		||||
          self,
 | 
			
		||||
          map_func,
 | 
			
		||||
          cycle_length,
 | 
			
		||||
          block_length,
 | 
			
		||||
          num_parallel_calls,
 | 
			
		||||
          deterministic=deterministic)
 | 
			
		||||
 | 
			
		||||
  def filter(self, predicate):
 | 
			
		||||
    """Filters this dataset according to `predicate`.
 | 
			
		||||
@ -4042,6 +4046,8 @@ class ParallelInterleaveDataset(UnaryDataset):
 | 
			
		||||
               cycle_length,
 | 
			
		||||
               block_length,
 | 
			
		||||
               num_parallel_calls,
 | 
			
		||||
               buffer_output_elements=AUTOTUNE,
 | 
			
		||||
               prefetch_input_elements=AUTOTUNE,
 | 
			
		||||
               deterministic=None):
 | 
			
		||||
    """See `Dataset.interleave()` for details."""
 | 
			
		||||
    self._input_dataset = input_dataset
 | 
			
		||||
@ -4056,6 +4062,15 @@ class ParallelInterleaveDataset(UnaryDataset):
 | 
			
		||||
        cycle_length, dtype=dtypes.int64, name="cycle_length")
 | 
			
		||||
    self._block_length = ops.convert_to_tensor(
 | 
			
		||||
        block_length, dtype=dtypes.int64, name="block_length")
 | 
			
		||||
    self._buffer_output_elements = ops.convert_to_tensor(
 | 
			
		||||
        buffer_output_elements,
 | 
			
		||||
        dtype=dtypes.int64,
 | 
			
		||||
        name="buffer_output_elements")
 | 
			
		||||
    self._prefetch_input_elements = ops.convert_to_tensor(
 | 
			
		||||
        prefetch_input_elements,
 | 
			
		||||
        dtype=dtypes.int64,
 | 
			
		||||
        name="prefetch_input_elements")
 | 
			
		||||
 | 
			
		||||
    self._num_parallel_calls = ops.convert_to_tensor(
 | 
			
		||||
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
 | 
			
		||||
    if deterministic is None:
 | 
			
		||||
@ -4065,7 +4080,21 @@ class ParallelInterleaveDataset(UnaryDataset):
 | 
			
		||||
    else:
 | 
			
		||||
      deterministic_string = "false"
 | 
			
		||||
 | 
			
		||||
    if deterministic is not None or compat.forward_compatible(2020, 2, 20):
 | 
			
		||||
    if (buffer_output_elements != AUTOTUNE or
 | 
			
		||||
        prefetch_input_elements != AUTOTUNE or
 | 
			
		||||
        compat.forward_compatible(2020, 3, 6)):
 | 
			
		||||
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v4(
 | 
			
		||||
          input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
 | 
			
		||||
          self._cycle_length,
 | 
			
		||||
          self._block_length,
 | 
			
		||||
          self._buffer_output_elements,
 | 
			
		||||
          self._prefetch_input_elements,
 | 
			
		||||
          self._num_parallel_calls,
 | 
			
		||||
          f=self._map_func.function,
 | 
			
		||||
          deterministic=deterministic_string,
 | 
			
		||||
          **self._flat_structure)
 | 
			
		||||
    elif deterministic is not None or compat.forward_compatible(2020, 2, 20):
 | 
			
		||||
      variant_tensor = gen_dataset_ops.parallel_interleave_dataset_v3(
 | 
			
		||||
          input_dataset._variant_tensor,  # pylint: disable=protected-access
 | 
			
		||||
          self._map_func.function.captured_inputs,  # pylint: disable=protected-access
 | 
			
		||||
 | 
			
		||||
@ -2636,6 +2636,10 @@ tf_module {
 | 
			
		||||
    name: "ParallelInterleaveDatasetV3"
 | 
			
		||||
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "ParallelInterleaveDatasetV4"
 | 
			
		||||
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "ParallelMapDataset"
 | 
			
		||||
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
 | 
			
		||||
 | 
			
		||||
@ -2636,6 +2636,10 @@ tf_module {
 | 
			
		||||
    name: "ParallelInterleaveDatasetV3"
 | 
			
		||||
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "ParallelInterleaveDatasetV4"
 | 
			
		||||
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'cycle_length\', \'block_length\', \'buffer_output_elements\', \'prefetch_input_elements\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'deterministic\', \'name\'], varargs=None, keywords=None, defaults=[\'default\', \'None\'], "
 | 
			
		||||
  }
 | 
			
		||||
  member_method {
 | 
			
		||||
    name: "ParallelMapDataset"
 | 
			
		||||
    argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'sloppy\', \'preserve_cardinality\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'False\', \'None\'], "
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user