[tf.data] Serialization and checkpointing related cleanup.
This CL: - removes unused `DatasetBase::Save()` and related tests - replaces `SerilizationContext::optimization_only` with multiple functionality specific flags (`check_external_state`, `fail_if_unimplemented`, and `serialize_data_tensors`) - introduces `DatasetBase::CheckExternalState` as an error-raising replacement for `DatasetBase::IsStateful` to make it possible to communicate the reason for why serialization failed through the error status - adds `IteratorBase::SaveInternal` and `IteratorBase::RestoreInternal` in preparation of making these methods pure virtual PiperOrigin-RevId: 262235093
This commit is contained in:
parent
43a408b8ac
commit
6d8f05acd7
@ -115,6 +115,15 @@ class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
std::vector<Tensor>* out_tensors) = 0;
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented("RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (reader_) {
|
||||
|
@ -97,7 +97,10 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "BigtableLookupDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -174,6 +177,17 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(
|
||||
"RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
Status ParseRow(IteratorContext* ctx,
|
||||
const ::google::cloud::bigtable::Row& row,
|
||||
|
@ -71,7 +71,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -76,7 +76,10 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -89,7 +89,10 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
|
||||
return "BigtableSampleKeyPairsDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -187,6 +190,17 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(
|
||||
"RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
size_t index_ GUARDED_BY(mu_) = 0;
|
||||
|
@ -64,7 +64,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -109,6 +112,17 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented("SaveInternal is currently not supported");
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(
|
||||
"RestoreInternal is currently not supported");
|
||||
}
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
size_t index_ = 0;
|
||||
|
@ -131,7 +131,10 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
|
||||
|
||||
BigtableTableResource* table() const { return table_; }
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on external state.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -264,9 +264,6 @@ Status GraphDefBuilderWrapper::AddFunction(
|
||||
<< " the graph. It will not be added again.";
|
||||
return Status::OK();
|
||||
}
|
||||
if (!ctx->optimization_only()) {
|
||||
TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(function_name, lib_def));
|
||||
}
|
||||
const FunctionDef* f_def = lib_def.Find(function_name);
|
||||
if (f_def == nullptr) {
|
||||
return errors::InvalidArgument("Unable to find FunctionDef for ",
|
||||
@ -369,29 +366,10 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetBase::Save(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) const {
|
||||
string serialized_graph_def;
|
||||
string output_node;
|
||||
GraphDefBuilder b;
|
||||
DatasetGraphDefBuilder db(&b);
|
||||
Node* node = nullptr;
|
||||
TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
|
||||
output_node = node->name();
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
|
||||
graph_def.SerializeToString(&serialized_graph_def);
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
|
||||
SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
|
||||
Status status = dataset->AsGraphDefInternal(ctx, this, output);
|
||||
if (ctx->optimization_only() && errors::IsUnimplemented(status)) {
|
||||
if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
|
||||
Tensor t(DT_VARIANT, TensorShape({}));
|
||||
// `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
|
||||
// increment the refcount of `dataset` here to retain ownership.
|
||||
|
@ -201,48 +201,6 @@ class GraphDefBuilderWrapper {
|
||||
private:
|
||||
void AddPlaceholderInternal(const Tensor& val, Node** output);
|
||||
void AddTensorInternal(const Tensor& val, Node** output);
|
||||
|
||||
Status EnsureFunctionIsStateless(
|
||||
const string& function_name,
|
||||
const FunctionLibraryDefinition& lib_def) const {
|
||||
const FunctionDef* function_def = lib_def.Find(function_name);
|
||||
if (!function_def) {
|
||||
return errors::InvalidArgument("Unable to find FunctionDef for ",
|
||||
function_name, " in registry.");
|
||||
}
|
||||
for (const NodeDef& node_def : function_def->node_def()) {
|
||||
const OpDef* op_def;
|
||||
TF_RETURN_IF_ERROR(lib_def.LookUpOpDef(node_def.op(), &op_def));
|
||||
// TODO(b/65524810): Hack to allow functions to capture Dataset op
|
||||
// nodes needed for FlatMap. Currently, source datasets nodes have been
|
||||
// marked stateful to avoid constant folding since we do not have a
|
||||
// good way of serializing them.
|
||||
if (IsOpWhitelisted(op_def)) {
|
||||
continue;
|
||||
}
|
||||
if (op_def->is_stateful()) {
|
||||
return errors::InvalidArgument(
|
||||
"Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
|
||||
"in function ", function_name, " is stateful. ",
|
||||
"Saving stateful functions is not supported yet.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns whether an op has been whitelisted for use inside map_fns.
|
||||
// Uses a heuristic to whitelist source dataset ops which have been
|
||||
// marked stateful due to b/65524810.
|
||||
// Also looks up the `op_def->name` in the global
|
||||
// `WhitelistedStatefulOpRegistry`.
|
||||
bool IsOpWhitelisted(const OpDef* op_def) const {
|
||||
return ((absl::EndsWith(op_def->name(), "Dataset") ||
|
||||
absl::EndsWith(op_def->name(), "DatasetV2")) &&
|
||||
op_def->output_arg_size() == 1 &&
|
||||
op_def->output_arg(0).type() == DT_VARIANT) ||
|
||||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
|
||||
}
|
||||
|
||||
bool HasAttr(const string& op_type_name, const string& attr_name) const;
|
||||
|
||||
bool HasAttr(const OpDef* op_def, const string& attr_name) const {
|
||||
@ -466,7 +424,23 @@ class SerializationContext {
|
||||
public:
|
||||
struct Params {
|
||||
std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
|
||||
bool optimization_only = false;
|
||||
|
||||
// Indicates whether serialization should check if the dataset depends on
|
||||
// external state. If the check is enabled and external state is
|
||||
// encountered, then the serialization will fail.
|
||||
bool check_external_state = true;
|
||||
|
||||
// Indicates whether an attempt to serialize a dataset that does not
|
||||
// implement serialization should result in an error. If set to `false`, the
|
||||
// serialized graph will replace the dataset with a placeholder returned in
|
||||
// `input_list`.
|
||||
bool fail_if_unimplemented = true;
|
||||
|
||||
// Indicates whether (potentionally large) data tensors should be
|
||||
// serialized, or replaced with a placeholder returned in `input_list`. The
|
||||
// latter makes sense to do when performing data agnostic graph rewrites to
|
||||
// reduce the memory usage.
|
||||
bool serialize_data_tensors = true;
|
||||
};
|
||||
|
||||
explicit SerializationContext(Params params) : params_(std::move(params)) {}
|
||||
@ -475,7 +449,11 @@ class SerializationContext {
|
||||
return params_.input_list;
|
||||
}
|
||||
|
||||
bool optimization_only() { return params_.optimization_only; }
|
||||
bool check_external_state() const { return params_.check_external_state; }
|
||||
|
||||
bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }
|
||||
|
||||
bool serialize_data_tensors() const { return params_.serialize_data_tensors; }
|
||||
|
||||
private:
|
||||
Params params_;
|
||||
@ -550,10 +528,6 @@ class IteratorBase {
|
||||
return RestoreInternal(ctx, reader);
|
||||
}
|
||||
|
||||
Status Restore(IteratorContext&& ctx, IteratorStateReader* reader) {
|
||||
return Restore(&ctx, reader);
|
||||
}
|
||||
|
||||
protected:
|
||||
// Returns a node that models this iterator.
|
||||
virtual std::shared_ptr<model::Node> CreateNode(
|
||||
@ -573,12 +547,22 @@ class IteratorBase {
|
||||
return input->RestoreInternal(ctx, reader);
|
||||
}
|
||||
|
||||
// Saves the state of this iterator recursively.
|
||||
// Saves the state of this iterator.
|
||||
//
|
||||
// This method is used to store the state of the iterator in a checkpoint.
|
||||
//
|
||||
// TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
|
||||
// implementations have an override.
|
||||
virtual Status SaveInternal(IteratorStateWriter* writer) {
|
||||
return errors::Unimplemented("SaveInternal");
|
||||
}
|
||||
|
||||
// Restores the state of this iterator recursively.
|
||||
// Restores the state of this iterator.
|
||||
//
|
||||
// This method is used to restore the state of the iterator from a checkpoint.
|
||||
//
|
||||
// TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
|
||||
// implementations have an override.
|
||||
virtual Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) {
|
||||
return errors::Unimplemented("RestoreInternal");
|
||||
@ -718,17 +702,25 @@ class DatasetBase : public core::RefCounted {
|
||||
// A human-readable debug string for this dataset.
|
||||
virtual string DebugString() const = 0;
|
||||
|
||||
// Serializes the dataset and writes it to the `writer`.
|
||||
virtual Status Save(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) const;
|
||||
// If the dataset is stateful it will not be possible to save its graph or
|
||||
// checkpoint the state of its iterators.
|
||||
//
|
||||
// TODO(jsimsa): Remove this method once all `DatasetBase` implementations are
|
||||
// migrated over to `CheckExternalState`.
|
||||
virtual bool IsStateful() const { return false; }
|
||||
|
||||
// Indicates whether the dataset depends on external mutable state case in
|
||||
// which case the serialization of the input pipeline graph and the
|
||||
// checkpointing of the input pipeline state will not be supported.
|
||||
// Indicates whether the dataset depends on any external state. If so, the
|
||||
// method returns `errors::FailedPrecondition` with a message that identifies
|
||||
// the external state. Otherwise, the method returns `Status::OK()`.
|
||||
//
|
||||
// TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
|
||||
// implementations have an override.
|
||||
virtual bool IsStateful() const { return false; }
|
||||
virtual Status CheckExternalState() const {
|
||||
if (IsStateful()) {
|
||||
return errors::FailedPrecondition("Dataset cannot be serialized.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
friend Status AsGraphDef(
|
||||
@ -739,11 +731,22 @@ class DatasetBase : public core::RefCounted {
|
||||
|
||||
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
|
||||
public:
|
||||
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
|
||||
explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
|
||||
: GraphDefBuilderWrapper(b) {}
|
||||
Status AddInputDataset(SerializationContext* ctx,
|
||||
const DatasetBase* dataset, Node** output);
|
||||
};
|
||||
|
||||
// Serializes the dataset into a `GraphDef`, which has two uses:
|
||||
//
|
||||
// 1) To perform static input pipeline optimizations, tf.data serializes the
|
||||
// dataset graph, applies graph rewrites, and then deserializes the graph.
|
||||
// If a subclass of `DatasetBase` does not implement this method, then it will
|
||||
// be excluded from static optimizations (and so will any upstream datasets).
|
||||
//
|
||||
// 2) To save the dataset so that it can restore at a later point (possibly in
|
||||
// different environment). If a subclass of `DatasetBase` does not implement
|
||||
// this method, then this migration will not be possible.
|
||||
virtual Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** node) const = 0;
|
||||
@ -802,10 +805,7 @@ class DatasetBaseIterator : public IteratorBase {
|
||||
bool* end_of_sequence) final;
|
||||
|
||||
Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
|
||||
if (params_.dataset->IsStateful()) {
|
||||
return errors::FailedPrecondition(
|
||||
"Saving iterator that depends on external state is not supported.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(params_.dataset->CheckExternalState());
|
||||
return IteratorBase::Save(ctx, writer);
|
||||
}
|
||||
|
||||
|
@ -101,7 +101,9 @@ class BatchDatasetOp::Dataset : public DatasetBase {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -449,46 +449,6 @@ TEST_P(ParameterizedBatchDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(batch_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedBatchDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> batch_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateBatchDatasetOpKernel(
|
||||
test_case.parallel_copy, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &batch_dataset_kernel));
|
||||
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||
test_case.range_dataset_param.start, test_case.range_dataset_param.end,
|
||||
test_case.range_dataset_param.step, "range", &range_dataset));
|
||||
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
TF_ASSERT_OK(
|
||||
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||
|
||||
Tensor batch_size = test_case.batch_size;
|
||||
Tensor drop_remainder = test_case.drop_remainder;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&range_dataset_tensor),
|
||||
TensorValue(&batch_size),
|
||||
TensorValue(&drop_remainder)};
|
||||
std::unique_ptr<OpKernelContext> batch_dataset_context;
|
||||
TF_ASSERT_OK(CreateBatchDatasetContext(batch_dataset_kernel.get(), &inputs,
|
||||
&batch_dataset_context));
|
||||
DatasetBase* batch_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(batch_dataset_kernel.get(),
|
||||
batch_dataset_context.get(), &batch_dataset));
|
||||
core::ScopedUnref scoped_unref_batch_dataset(batch_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(batch_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedBatchDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -101,7 +101,9 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -669,7 +671,9 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -971,7 +975,10 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
|
||||
MemoryCache* cache, const Tensor& resource_handle)
|
||||
: MemoryDataset(ctx, input, cache), resource_handle_(resource_handle) {}
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(DebugString(),
|
||||
" depends on memory cache resource.");
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -353,39 +353,6 @@ TEST_P(ParameterizedCacheDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(cache_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedCacheDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> cache_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&cache_dataset_kernel));
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor file_name = CreateTensor<string>(TensorShape{}, {test_case.file_name});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
|
||||
std::unique_ptr<OpKernelContext> cache_dataset_context;
|
||||
TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
|
||||
&cache_dataset_context));
|
||||
DatasetBase* cache_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
|
||||
cache_dataset_context.get(), &cache_dataset));
|
||||
core::ScopedUnref scoped_unref(cache_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(cache_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedCacheDatasetOpTest, IteratorOutputShapes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -210,17 +210,19 @@ Status CreateFunctionLibraryDefinition(
|
||||
return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
|
||||
}
|
||||
|
||||
bool IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
const NodeDef& node);
|
||||
Status IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
const NodeDef& node);
|
||||
|
||||
bool IsFunctionStateful(const FunctionLibraryDefinition& library,
|
||||
const FunctionDef& function_def) {
|
||||
if (!function_def.signature().is_stateful()) return false;
|
||||
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
|
||||
const FunctionDef& function_def) {
|
||||
if (!function_def.signature().is_stateful()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (const NodeDef& node_def : function_def.node_def()) {
|
||||
if (IsNodeStateful(library, node_def)) return true;
|
||||
TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
|
||||
}
|
||||
return false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Returns whether an op has been whitelisted as stateless. Uses a heuristic to
|
||||
@ -228,27 +230,23 @@ bool IsFunctionStateful(const FunctionLibraryDefinition& library,
|
||||
// b/65524810. Also looks up the `op_def->name` in the global
|
||||
// `WhitelistedStatefulOpRegistry`.
|
||||
bool IsOpWhitelisted(const OpDef* op_def) {
|
||||
return ((absl::EndsWith(op_def->name(), "Dataset") ||
|
||||
absl::EndsWith(op_def->name(), "DatasetV2")) &&
|
||||
op_def->output_arg_size() == 1 &&
|
||||
op_def->output_arg(0).type() == DT_VARIANT) ||
|
||||
return (op_def->output_arg_size() == 1 &&
|
||||
op_def->output_arg(0).type() == DT_VARIANT &&
|
||||
(absl::EndsWith(op_def->name(), "Dataset") ||
|
||||
absl::EndsWith(op_def->name(), "DatasetV2"))) ||
|
||||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
|
||||
}
|
||||
|
||||
bool IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
const NodeDef& node) {
|
||||
Status IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
const NodeDef& node) {
|
||||
const OpDef* op_def;
|
||||
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
|
||||
if (!s.ok()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (IsOpWhitelisted(op_def)) return false;
|
||||
|
||||
if (!op_def->is_stateful()) return false;
|
||||
|
||||
if (op_def->name() == "Assert") {
|
||||
return false;
|
||||
// TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
|
||||
// `LookUpOpDef` errors here.
|
||||
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
|
||||
IsOpWhitelisted(op_def) || !op_def->is_stateful() ||
|
||||
op_def->name() == "Assert") {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (op_def->name() == "If") {
|
||||
@ -256,10 +254,13 @@ bool IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
library.Find(node.attr().at("then_branch").func().name());
|
||||
const FunctionDef* else_func =
|
||||
library.Find(node.attr().at("else_branch").func().name());
|
||||
if ((then_func != nullptr && !IsFunctionStateful(library, *then_func)) &&
|
||||
(else_func != nullptr && !IsFunctionStateful(library, *else_func))) {
|
||||
return false;
|
||||
if (then_func != nullptr) {
|
||||
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
|
||||
}
|
||||
if (else_func != nullptr) {
|
||||
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (op_def->name() == "While") {
|
||||
@ -267,12 +268,16 @@ bool IsNodeStateful(const FunctionLibraryDefinition& library,
|
||||
library.Find(node.attr().at("cond").func().name());
|
||||
const FunctionDef* body_func =
|
||||
library.Find(node.attr().at("body").func().name());
|
||||
if ((cond_func != nullptr && !IsFunctionStateful(library, *cond_func)) &&
|
||||
(body_func != nullptr && !IsFunctionStateful(library, *body_func))) {
|
||||
return false;
|
||||
if (cond_func != nullptr) {
|
||||
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
|
||||
}
|
||||
if (body_func != nullptr) {
|
||||
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
return true;
|
||||
|
||||
return errors::FailedPrecondition(op_def->name(), " is stateful.");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -482,13 +487,14 @@ Status CapturedFunction::Instantiate(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool CapturedFunction::IsStateful() const {
|
||||
bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); }
|
||||
|
||||
Status CapturedFunction::CheckExternalState() const {
|
||||
for (const auto& name : lib_def()->ListFunctionNames()) {
|
||||
if (IsFunctionStateful(*lib_def(), *(lib_def()->Find(name)))) {
|
||||
return true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
|
||||
}
|
||||
return false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -205,8 +205,14 @@ class CapturedFunction {
|
||||
instantiated_captured_function);
|
||||
|
||||
// Determines whether the captured function is stateful.
|
||||
//
|
||||
// TODO(jsimsa): Remove this method once all users of `CapturedFunction`
|
||||
// migrate to `CheckExternalState`.
|
||||
bool IsStateful() const;
|
||||
|
||||
// Determines whether the captured function is stateful.
|
||||
Status CheckExternalState() const;
|
||||
|
||||
// Returns the additional captured inputs that will be passed to the function.
|
||||
const std::vector<Tensor>& captured_inputs() const {
|
||||
return captured_inputs_;
|
||||
|
@ -85,8 +85,9 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
|
||||
return n1 + n2;
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return input_->IsStateful() || to_concatenate_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(input_->CheckExternalState());
|
||||
return to_concatenate_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -357,39 +357,6 @@ TEST_P(ParameterizedConcatenateDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(concatenate_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_F(ConcatenateDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = SameShapeTestCase();
|
||||
std::vector<Tensor> tensor_slice_dataset_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
|
||||
&tensor_slice_dataset_tensors));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &tensor : tensor_slice_dataset_tensors) {
|
||||
inputs.emplace_back(&tensor);
|
||||
}
|
||||
std::unique_ptr<OpKernel> dataset_kernel;
|
||||
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
|
||||
&dataset_kernel_ctx));
|
||||
DatasetBase *concatenate_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||
&concatenate_dataset));
|
||||
core::ScopedUnref scoped_unref(concatenate_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(concatenate_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -593,16 +593,6 @@ Status DatasetOpsTestBase::CheckDatasetCardinality(const DatasetBase& dataset,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckDatasetSave(const DatasetBase& dataset) {
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_EXPECT_OK(dataset.Save(serialization_context.get(), &writer));
|
||||
TF_RETURN_IF_ERROR(writer.Flush());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DatasetOpsTestBase::CheckDatasetIsStateful(const DatasetBase& dataset,
|
||||
bool expected_stateful) {
|
||||
EXPECT_EQ(dataset.IsStateful(), expected_stateful);
|
||||
|
@ -317,9 +317,6 @@ class DatasetOpsTestBase : public ::testing::Test {
|
||||
Status CheckDatasetCardinality(const DatasetBase& dataset,
|
||||
int64 expected_cardinality);
|
||||
|
||||
// Checks `DatasetBase::Save()`.
|
||||
Status CheckDatasetSave(const DatasetBase& dataset);
|
||||
|
||||
// Checks `DatasetBase::IsStateful()`.
|
||||
Status CheckDatasetIsStateful(const DatasetBase& dataset,
|
||||
bool expected_stateful);
|
||||
|
@ -375,13 +375,16 @@ uint64 HashSubgraphFunctionImpl(
|
||||
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
|
||||
SerializationContext&& serialization_ctx,
|
||||
GraphDef* graph_def) {
|
||||
if (serialization_ctx.check_external_state()) {
|
||||
TF_RETURN_IF_ERROR(dataset->CheckExternalState());
|
||||
}
|
||||
GraphDefBuilder b;
|
||||
DatasetBase::DatasetGraphDefBuilder db(&b);
|
||||
Node* output_node = nullptr;
|
||||
TF_RETURN_IF_ERROR(
|
||||
db.AddInputDataset(&serialization_ctx, dataset, &output_node));
|
||||
// Insert a purely symbolic _Retval node to indicate to consumers which Tensor
|
||||
// represents this Dataset.
|
||||
// Insert a purely symbolic _Retval node to indicate to consumers which node
|
||||
// represents `dataset`.
|
||||
ops::UnaryOp("_Retval", output_node,
|
||||
b.opts()
|
||||
.WithName("dataset")
|
||||
@ -415,7 +418,9 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
|
||||
SerializationContext::Params params;
|
||||
std::vector<std::pair<string, Tensor>> input_list;
|
||||
params.input_list = &input_list;
|
||||
params.optimization_only = true;
|
||||
params.check_external_state = false;
|
||||
params.fail_if_unimplemented = false;
|
||||
params.serialize_data_tensors = false;
|
||||
SerializationContext serialization_ctx(params);
|
||||
GraphDef graph_def;
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -63,7 +63,9 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -376,43 +376,6 @@ TEST_P(ParameterizedAssertNextDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor range_and_take_dataset_tensor;
|
||||
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
|
||||
test_case.take_dataset_params,
|
||||
&range_and_take_dataset_tensor));
|
||||
|
||||
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&assert_next_dataset_kernel));
|
||||
Tensor transformations = test_case.transformations;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_and_take_dataset_tensor),
|
||||
TensorValue(&transformations)});
|
||||
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
|
||||
TF_ASSERT_OK(CreateAssertNextDatasetContext(
|
||||
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
|
||||
|
||||
DatasetBase* assert_next_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
|
||||
assert_next_dataset_context.get(),
|
||||
&assert_next_dataset));
|
||||
core::ScopedUnref scoped_unref(assert_next_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -242,13 +242,11 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& captured_func : captured_funcs_) {
|
||||
if (captured_func->IsStateful()) {
|
||||
return true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(captured_func->CheckExternalState());
|
||||
}
|
||||
return input_->IsStateful();
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -158,13 +158,11 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return cardinality_; }
|
||||
|
||||
bool IsStateful() const override {
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& input : inputs_) {
|
||||
if (input->IsStateful()) {
|
||||
return true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(input->CheckExternalState());
|
||||
}
|
||||
return false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -170,6 +170,8 @@ class CSVDatasetOp : public DatasetOpKernel {
|
||||
|
||||
string DebugString() const override { return "CSVDatasetOp::Dataset"; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -120,7 +120,9 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -107,13 +107,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
||||
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& input : data_inputs_) {
|
||||
if (input->IsStateful()) {
|
||||
return true;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(input->CheckExternalState());
|
||||
}
|
||||
return selector_input_->IsStateful();
|
||||
return selector_input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -112,11 +112,12 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "GroupByReducerDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_key_func_->IsStateful() ||
|
||||
captured_init_func_->IsStateful() ||
|
||||
captured_reduce_func_->IsStateful() ||
|
||||
captured_finalize_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(captured_init_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(captured_finalize_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -108,10 +108,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "GroupByWindowDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_key_func_->IsStateful() ||
|
||||
captured_reduce_func_->IsStateful() ||
|
||||
captured_window_size_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(captured_window_size_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -60,7 +60,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -70,6 +70,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
|
||||
|
||||
string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -143,8 +143,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
(n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -81,6 +81,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
|
||||
return "MatchingFilesDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -69,13 +69,16 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "NonSerializableDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization.");
|
||||
}
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
@ -121,8 +121,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -569,46 +569,6 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Cardinality) {
|
||||
test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, ¶llel_interleave_dataset_kernel));
|
||||
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&tensor_slice_dataset_tensor),
|
||||
TensorValue(&test_case.cycle_length),
|
||||
TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
|
||||
TensorValue(&test_case.buffer_output_elements),
|
||||
TensorValue(&test_case.prefetch_input_elements)});
|
||||
std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
|
||||
TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
|
||||
parallel_interleave_dataset_kernel.get(), &inputs,
|
||||
¶llel_interleave_dataset_context));
|
||||
DatasetBase* parallel_interleave_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
|
||||
parallel_interleave_dataset_context.get(),
|
||||
¶llel_interleave_dataset));
|
||||
core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(
|
||||
parallel_interleave_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -206,7 +206,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -76,6 +76,8 @@ class RandomDatasetOp : public DatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return kInfiniteCardinality; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -79,7 +79,9 @@ class SamplingDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
string DebugString() const override { return "SamplingDatasetOp::Dataset"; }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -102,8 +102,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -138,7 +138,9 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -67,7 +67,9 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -108,7 +108,9 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
return n / window_shift_;
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -314,7 +314,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
SerializationContext::Params params;
|
||||
std::vector<std::pair<string, Tensor>> input_list;
|
||||
params.input_list = &input_list;
|
||||
params.optimization_only = true;
|
||||
params.check_external_state = false;
|
||||
|
||||
GraphDef graph_def;
|
||||
OP_REQUIRES_OK(
|
||||
@ -376,7 +376,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -103,6 +103,8 @@ class SqlDatasetOp : public DatasetOpKernel {
|
||||
|
||||
string DebugString() const override { return "SqlDatasetOp::Dataset"; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -78,6 +78,10 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
@ -189,6 +193,10 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -84,8 +84,9 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return kUnknownCardinality; }
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -172,7 +172,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -281,7 +283,9 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
@ -383,7 +387,9 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -69,7 +69,9 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
string DebugString() const override { return "UnbatchDatasetOp::Dataset"; }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -72,7 +72,9 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
|
||||
return strings::StrCat("UniqueDatasetOp::Dataset");
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -75,8 +75,9 @@ class FilterDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -346,39 +346,6 @@ TEST_P(ParameterizedFilterDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(filter_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedFilterDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> filter_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateFilterDatasetKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &filter_dataset_kernel));
|
||||
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&tensor_slice_dataset_tensor)});
|
||||
std::unique_ptr<OpKernelContext> filter_dataset_context;
|
||||
TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs,
|
||||
&filter_dataset_context));
|
||||
DatasetBase *filter_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(),
|
||||
filter_dataset_context.get(), &filter_dataset));
|
||||
core::ScopedUnref scoped_unref(filter_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(filter_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedFilterDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
|
@ -93,6 +93,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType, params);
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -452,56 +452,6 @@ TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, Cardinality) {
|
||||
test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
TF_ASSERT_OK(CreateTestFiles(test_case));
|
||||
|
||||
std::unique_ptr<OpKernel> fixed_length_record_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateFixedLengthRecordDatasetOpKernel(
|
||||
&fixed_length_record_dataset_kernel));
|
||||
|
||||
int64 num_files = test_case.filenames.size();
|
||||
Tensor filenames =
|
||||
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
|
||||
Tensor header_bytes =
|
||||
CreateTensor<int64>(TensorShape({}), {test_case.header_bytes});
|
||||
Tensor record_bytes =
|
||||
CreateTensor<int64>(TensorShape({}), {test_case.record_bytes});
|
||||
Tensor footer_bytes =
|
||||
CreateTensor<int64>(TensorShape({}), {test_case.footer_bytes});
|
||||
Tensor buffer_size =
|
||||
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
|
||||
Tensor compression_type = CreateTensor<string>(
|
||||
TensorShape({}), {ToString(test_case.compression_type)});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs{
|
||||
TensorValue(&filenames), TensorValue(&header_bytes),
|
||||
TensorValue(&record_bytes), TensorValue(&footer_bytes),
|
||||
TensorValue(&buffer_size), TensorValue(&compression_type),
|
||||
};
|
||||
std::unique_ptr<OpKernelContext> fixed_length_record_dataset_context;
|
||||
TF_ASSERT_OK(CreateFixedLengthRecordDatasetContext(
|
||||
fixed_length_record_dataset_kernel.get(), &inputs,
|
||||
&fixed_length_record_dataset_context));
|
||||
|
||||
DatasetBase* fixed_length_record_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(fixed_length_record_dataset_kernel.get(),
|
||||
fixed_length_record_dataset_context.get(),
|
||||
&fixed_length_record_dataset));
|
||||
core::ScopedUnref scoped_unref(fixed_length_record_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(
|
||||
fixed_length_record_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -75,8 +75,9 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -364,41 +364,6 @@ TEST_P(ParameterizedFlatMapDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(flat_map_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_F(FlatMapDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> flat_map_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateFlatMapDatasetKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &flat_map_dataset_kernel));
|
||||
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&tensor_slice_dataset_tensor)});
|
||||
std::unique_ptr<OpKernelContext> flat_map_dataset_context;
|
||||
TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(),
|
||||
&inputs, &flat_map_dataset_context));
|
||||
DatasetBase *flat_map_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(),
|
||||
flat_map_dataset_context.get(),
|
||||
&flat_map_dataset));
|
||||
core::ScopedUnref scoped_unref(flat_map_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(flat_map_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedFlatMapDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
|
@ -74,17 +74,18 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return init_func_->IsStateful() || next_func_->IsStateful() ||
|
||||
finalize_func_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
|
||||
return finalize_func_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
return errors::Unimplemented("%s does not support serialization",
|
||||
DebugString());
|
||||
return errors::Unimplemented(DebugString(),
|
||||
" does not support serialization");
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -81,8 +81,9 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -570,43 +570,6 @@ TEST_P(ParameterizedInterleaveDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(interleave_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedInterleaveDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> interleave_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateInterleaveDatasetKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &interleave_dataset_kernel));
|
||||
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor cycle_length = test_case.cycle_length;
|
||||
Tensor block_length = test_case.block_length;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&cycle_length),
|
||||
TensorValue(&block_length)});
|
||||
std::unique_ptr<OpKernelContext> interleave_dataset_context;
|
||||
TF_ASSERT_OK(CreateInterleaveDatasetContext(
|
||||
interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context));
|
||||
DatasetBase *interleave_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(),
|
||||
interleave_dataset_context.get(),
|
||||
&interleave_dataset));
|
||||
core::ScopedUnref scoped_unref(interleave_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(interleave_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedInterleaveDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
|
@ -126,8 +126,8 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
|
||||
params.cancellation_manager,
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
return captured_state->iterator->Restore(IteratorContext(std::move(params)),
|
||||
reader);
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
return captured_state->iterator->Restore(&iter_ctx, reader);
|
||||
}
|
||||
return errors::FailedPrecondition(
|
||||
"Restore() failed because the iterator has not been initialized. Ensure "
|
||||
|
@ -74,8 +74,9 @@ class MapDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -341,43 +341,6 @@ TEST_P(ParameterizedMapDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
|
||||
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
// The ownership of range_dataset is transferred to DatasetVariantWrapper,
|
||||
// which will handle the release of memory.
|
||||
TF_ASSERT_OK(
|
||||
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
|
||||
map_dataset_inputs.emplace_back(&range_dataset_tensor);
|
||||
|
||||
std::unique_ptr<OpKernel> map_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateMapDatasetOpKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &map_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> map_dataset_context;
|
||||
TF_ASSERT_OK(CreateMapDatasetContext(
|
||||
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
|
||||
DatasetBase* map_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
|
||||
map_dataset_context.get(), &map_dataset));
|
||||
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(map_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = TestCase1();
|
||||
|
@ -85,7 +85,9 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -114,7 +114,9 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -785,53 +785,6 @@ TEST_P(ParameterizedPaddedBatchDatasetOpTest, Cardinality) {
|
||||
test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedPaddedBatchDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> padded_batch_dataset_kernel;
|
||||
TF_ASSERT_OK(CreatePaddedBatchDatasetKernel(
|
||||
test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &padded_batch_dataset_kernel));
|
||||
|
||||
Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
TF_ASSERT_OK(CreateConcatenateDatasetTensor(
|
||||
test_case.input_tensors, test_case.concatenate_output_dtypes,
|
||||
test_case.concatenate_output_shapes, &concatenate_dataset_tensor));
|
||||
Tensor batch_size = test_case.batch_size;
|
||||
std::vector<Tensor> padded_shapes = test_case.padded_shapes;
|
||||
std::vector<Tensor> padding_values = test_case.padding_values;
|
||||
Tensor drop_remainder = test_case.drop_remainder;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&concatenate_dataset_tensor), TensorValue(&batch_size)});
|
||||
for (auto &padded_shape : padded_shapes) {
|
||||
inputs.emplace_back(&padded_shape);
|
||||
}
|
||||
for (auto &padding_value : padding_values) {
|
||||
inputs.emplace_back(&padding_value);
|
||||
}
|
||||
inputs.emplace_back(&drop_remainder);
|
||||
|
||||
std::unique_ptr<OpKernelContext> padded_batch_dataset_context;
|
||||
TF_ASSERT_OK(
|
||||
CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(),
|
||||
&inputs, &padded_batch_dataset_context));
|
||||
DatasetBase *padded_batch_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(),
|
||||
padded_batch_dataset_context.get(),
|
||||
&padded_batch_dataset));
|
||||
core::ScopedUnref scoped_unref(padded_batch_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(padded_batch_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedPaddedBatchDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
|
@ -153,6 +153,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
ParallelInterleaveDatasetOp::kDatasetType, params);
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -724,47 +724,6 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Cardinality) {
|
||||
test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, test_case.sloppy,
|
||||
¶llel_interleave_dataset_kernel));
|
||||
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor cycle_length = test_case.cycle_length;
|
||||
Tensor block_length = test_case.block_length;
|
||||
Tensor num_parallel_calls = test_case.num_parallel_calls;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&cycle_length),
|
||||
TensorValue(&block_length), TensorValue(&num_parallel_calls)});
|
||||
std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
|
||||
TF_ASSERT_OK(CreateInterleaveDatasetContext(
|
||||
parallel_interleave_dataset_kernel.get(), &inputs,
|
||||
¶llel_interleave_dataset_context));
|
||||
DatasetBase *parallel_interleave_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
|
||||
parallel_interleave_dataset_context.get(),
|
||||
¶llel_interleave_dataset));
|
||||
core::ScopedUnref scoped_unref(parallel_interleave_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(
|
||||
parallel_interleave_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
const TestCase &test_case = GetParam();
|
||||
|
@ -89,8 +89,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override {
|
||||
return captured_func_->IsStateful() || input_->IsStateful();
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -527,49 +527,6 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Cardinality) {
|
||||
test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedParallelMapDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> parallel_map_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateParallelMapDatasetOpKernel(
|
||||
test_case.func, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, test_case.use_inter_op_parallelism,
|
||||
test_case.sloppy, test_case.preserve_cardinality,
|
||||
¶llel_map_dataset_kernel));
|
||||
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||
test_case.range_data_param.start, test_case.range_data_param.end,
|
||||
test_case.range_data_param.step, "range", &range_dataset));
|
||||
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
TF_ASSERT_OK(
|
||||
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||
Tensor num_parallel_calls = test_case.num_parallel_calls;
|
||||
gtl::InlinedVector<TensorValue, 4> parallel_map_dataset_inputs(
|
||||
{TensorValue(&range_dataset_tensor), TensorValue(&num_parallel_calls)});
|
||||
|
||||
std::unique_ptr<OpKernelContext> parallel_map_dataset_context;
|
||||
TF_ASSERT_OK(CreateParallelMapDatasetContext(
|
||||
parallel_map_dataset_kernel.get(), ¶llel_map_dataset_inputs,
|
||||
¶llel_map_dataset_context));
|
||||
DatasetBase* parallel_map_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(),
|
||||
parallel_map_dataset_context.get(),
|
||||
¶llel_map_dataset));
|
||||
core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(
|
||||
parallel_map_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedParallelMapDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -83,7 +83,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -398,43 +398,6 @@ TEST_P(ParameterizedPrefetchDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(prefetch_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_F(PrefetchDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PositiveBufferSizeTestCase();
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor buffer_size =
|
||||
CreateTensor<int64>(TensorShape{}, {test_case.buffer_size});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs_for_prefetch_dataset(
|
||||
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&buffer_size)});
|
||||
|
||||
std::unique_ptr<OpKernel> prefetch_dataset_kernel;
|
||||
TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&prefetch_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> prefetch_dataset_context;
|
||||
TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(),
|
||||
&inputs_for_prefetch_dataset,
|
||||
&prefetch_dataset_context));
|
||||
DatasetBase *prefetch_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(),
|
||||
prefetch_dataset_context.get(),
|
||||
&prefetch_dataset));
|
||||
core::ScopedUnref scoped_unref(prefetch_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(prefetch_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_F(PrefetchDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -73,6 +73,8 @@ class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -299,28 +299,6 @@ DatasetSaveTestCase<RangeDatasetParams> DatasetSaveTestCase1() {
|
||||
return {/*dataset_params=*/PositiveStepRangeDataset()};
|
||||
}
|
||||
|
||||
TEST_F(RangeDatasetOpTest, DatasetSave) {
|
||||
int64 thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
auto test_case = DatasetSaveTestCase1();
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
|
||||
std::unique_ptr<OpKernel> range_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
|
||||
test_case.dataset_params.node_name, &range_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> range_dataset_context;
|
||||
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
|
||||
&range_dataset_context));
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
|
||||
range_dataset_context.get(), &range_dataset));
|
||||
core::ScopedUnref scoped_unref(range_dataset);
|
||||
|
||||
TF_ASSERT_OK(CheckDatasetSave(*range_dataset));
|
||||
}
|
||||
|
||||
IsStatefulTestCase<RangeDatasetParams> IsStatefulTestCase1() {
|
||||
return {/*dataset_params=*/PositiveStepRangeDataset(),
|
||||
/*expected_stateful=*/false};
|
||||
|
@ -89,7 +89,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
return count_ * n;
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -349,41 +349,6 @@ TEST_P(ParameterizedDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(repeat_dataset->Cardinality(), GetParam().expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_F(RepeatDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
const TestCase &test_case = FiniteRepeatTestCase();
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs_for_repeat_dataset;
|
||||
inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor);
|
||||
inputs_for_repeat_dataset.emplace_back(&count);
|
||||
|
||||
std::unique_ptr<OpKernel> repeat_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&repeat_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> repeat_dataset_context;
|
||||
TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(),
|
||||
&inputs_for_repeat_dataset,
|
||||
&repeat_dataset_context));
|
||||
DatasetBase *repeat_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(),
|
||||
repeat_dataset_context.get(), &repeat_dataset));
|
||||
core::ScopedUnref scoped_unref(repeat_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(repeat_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -79,7 +79,9 @@ class ShardDatasetOp::Dataset : public DatasetBase {
|
||||
return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -497,47 +497,6 @@ TEST_P(ParameterizedShardDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(shard_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedShardDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> shard_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateShardDatasetOpKernel(
|
||||
test_case.require_non_empty, test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &shard_dataset_kernel));
|
||||
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||
test_case.range_dataset_param.start, test_case.range_dataset_param.end,
|
||||
test_case.range_dataset_param.step, "range", &range_dataset));
|
||||
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
TF_ASSERT_OK(
|
||||
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||
|
||||
Tensor num_shards = test_case.num_shards;
|
||||
Tensor index = test_case.index;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs({TensorValue(&range_dataset_tensor),
|
||||
TensorValue(&num_shards),
|
||||
TensorValue(&index)});
|
||||
std::unique_ptr<OpKernelContext> shard_dataset_context;
|
||||
TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs,
|
||||
&shard_dataset_context));
|
||||
|
||||
DatasetBase* shard_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(),
|
||||
shard_dataset_context.get(), &shard_dataset));
|
||||
core::ScopedUnref scoped_unref_batch_dataset(shard_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(shard_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedShardDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -103,7 +103,9 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
template <class T>
|
||||
@ -540,7 +542,10 @@ class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType, params);
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return true; }
|
||||
Status CheckExternalState() const override {
|
||||
return errors::FailedPrecondition(
|
||||
DebugString(), " depends on random seed generator resource.");
|
||||
}
|
||||
|
||||
std::unique_ptr<IteratorBase> MakeIteratorInternal(
|
||||
const string& prefix) const override {
|
||||
|
@ -606,51 +606,6 @@ TEST_P(ParameterizedShuffleDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
Tensor count = test_case.count;
|
||||
int64 count_value = count.flat<int64>()(0);
|
||||
std::unique_ptr<OpKernel> dataset_kernel;
|
||||
TF_ASSERT_OK(
|
||||
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
|
||||
test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes, &dataset_kernel));
|
||||
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||
test_case.range_data_param.start, test_case.range_data_param.end,
|
||||
test_case.range_data_param.step, "range", &range_dataset));
|
||||
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
TF_ASSERT_OK(
|
||||
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||
Tensor buffer_size = test_case.buffer_size;
|
||||
Tensor seed = test_case.seed;
|
||||
Tensor seed2 = test_case.seed2;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_dataset_tensor), TensorValue(&buffer_size),
|
||||
TensorValue(&seed), TensorValue(&seed2)});
|
||||
if (count_value != 1) inputs.push_back(TensorValue(&count));
|
||||
|
||||
std::unique_ptr<OpKernelContext> dataset_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
|
||||
DatasetBase* dataset;
|
||||
TF_ASSERT_OK(
|
||||
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
|
||||
core::ScopedUnref scoped_unref_dataset(dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -75,7 +75,9 @@ class SkipDatasetOp::Dataset : public DatasetBase {
|
||||
return count_ < 0 ? 0 : std::max(0LL, n - count_);
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -356,41 +356,6 @@ TEST_P(ParameterizedSkipDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(skip_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_F(SkipDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = SkipLessTestCase();
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs_for_skip_dataset(
|
||||
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&count)});
|
||||
|
||||
std::unique_ptr<OpKernel> skip_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&skip_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> skip_dataset_context;
|
||||
TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(),
|
||||
&inputs_for_skip_dataset,
|
||||
&skip_dataset_context));
|
||||
DatasetBase *skip_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(),
|
||||
skip_dataset_context.get(), &skip_dataset));
|
||||
core::ScopedUnref scoped_unref(skip_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(skip_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedSkipDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -56,6 +56,8 @@ class Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return sparse_tensor_.shape()[0]; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -324,38 +324,6 @@ TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(dataset->Cardinality(), expected_outputs.size());
|
||||
}
|
||||
|
||||
TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = TwoDimsTestCase();
|
||||
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
|
||||
std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
|
||||
DataType tvalues = input_sparse_tensor.values.dtype();
|
||||
gtl::InlinedVector<TensorValue, 4> inputs = {
|
||||
TensorValue(&input_sparse_tensor.indices),
|
||||
TensorValue(&input_sparse_tensor.values),
|
||||
TensorValue(&input_sparse_tensor.dense_shape)};
|
||||
|
||||
std::unique_ptr<OpKernel> dataset_kernel;
|
||||
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
|
||||
dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
|
||||
DatasetBase *dataset;
|
||||
TF_ASSERT_OK(
|
||||
CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
|
||||
core::ScopedUnref scoped_unref(dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -71,7 +71,9 @@ int64 TakeDataset::Cardinality() const {
|
||||
return std::min(n, count_);
|
||||
}
|
||||
|
||||
bool TakeDataset::IsStateful() const { return input_->IsStateful(); }
|
||||
Status TakeDataset::CheckExternalState() const {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
|
||||
public:
|
||||
|
@ -40,7 +40,7 @@ class TakeDataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override;
|
||||
|
||||
bool IsStateful() const override;
|
||||
Status CheckExternalState() const override;
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -351,41 +351,6 @@ TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(take_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_F(TakeDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
const TestCase &test_case = TakeLessTestCase();
|
||||
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
|
||||
&tensor_slice_dataset_tensor));
|
||||
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
|
||||
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
|
||||
inputs_for_take_dataset.emplace_back(&count);
|
||||
|
||||
std::unique_ptr<OpKernel> take_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&take_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> take_dataset_context;
|
||||
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
|
||||
&inputs_for_take_dataset,
|
||||
&take_dataset_context));
|
||||
DatasetBase *take_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
|
||||
take_dataset_context.get(), &take_dataset));
|
||||
core::ScopedUnref scoped_unref(take_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(take_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -62,6 +62,8 @@ class TensorDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return 1LL; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
@ -70,12 +72,12 @@ class TensorDatasetOp::Dataset : public DatasetBase {
|
||||
components.reserve(tensors_.size());
|
||||
for (const Tensor& t : tensors_) {
|
||||
Node* node;
|
||||
if (ctx->optimization_only()) {
|
||||
if (ctx->serialize_data_tensors()) {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
|
||||
DCHECK_NE(ctx->input_list(), nullptr);
|
||||
ctx->input_list()->emplace_back(node->name(), t);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
components.emplace_back(node);
|
||||
}
|
||||
|
@ -305,37 +305,6 @@ TEST_F(TensorDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(tensor_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = GetParam();
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.push_back(TensorValue(&component));
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&tensor_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_dataset_context;
|
||||
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
|
||||
&tensor_dataset_context));
|
||||
DatasetBase *tensor_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
|
||||
tensor_dataset_context.get(), &tensor_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(tensor_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -68,6 +68,8 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return tensors_[0].dim_size(0); }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
@ -76,12 +78,12 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
components.reserve(tensors_.size());
|
||||
for (const Tensor& t : tensors_) {
|
||||
Node* node;
|
||||
if (ctx->optimization_only()) {
|
||||
if (ctx->serialize_data_tensors()) {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
|
||||
DCHECK_NE(ctx->input_list(), nullptr);
|
||||
ctx->input_list()->emplace_back(node->name(), t);
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
|
||||
}
|
||||
components.emplace_back(node);
|
||||
}
|
||||
|
@ -387,48 +387,6 @@ TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
|
||||
}
|
||||
|
||||
TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestCase &test_case = PlainTensorTestCase();
|
||||
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
|
||||
std::vector<Tensor> components = test_case.components;
|
||||
DataTypeVector dtypes;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto &component : components) {
|
||||
inputs.emplace_back(&component);
|
||||
dtypes.emplace_back(component.dtype());
|
||||
}
|
||||
size_t num_tensors_per_slice = components.size();
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
shapes.reserve(num_tensors_per_slice);
|
||||
for (int i = 0; i < num_tensors_per_slice; ++i) {
|
||||
shapes.emplace_back(expected_outputs[i].shape());
|
||||
}
|
||||
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
|
||||
&tensor_slice_dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
|
||||
TF_ASSERT_OK(
|
||||
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
|
||||
&inputs, &tensor_slice_dataset_context));
|
||||
DatasetBase *tensor_slice_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
|
||||
tensor_slice_dataset_context.get(),
|
||||
&tensor_slice_dataset));
|
||||
core::ScopedUnref scoped_unref(tensor_slice_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(
|
||||
tensor_slice_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -70,6 +70,8 @@ class TextLineDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -366,45 +366,6 @@ TEST_P(ParameterizedTextLineDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(text_line_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedTextLineDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
TF_ASSERT_OK(CreateTestFiles(test_case));
|
||||
|
||||
std::unique_ptr<OpKernel> text_line_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTextLineDatasetOpKernel(&text_line_dataset_kernel));
|
||||
|
||||
int64 num_files = test_case.filenames.size();
|
||||
Tensor filenames =
|
||||
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
|
||||
Tensor compression_type = CreateTensor<string>(
|
||||
TensorShape({}), {ToString(test_case.compression_type)});
|
||||
Tensor buffer_size =
|
||||
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
|
||||
TensorValue(&compression_type),
|
||||
TensorValue(&buffer_size)};
|
||||
std::unique_ptr<OpKernelContext> text_line_dataset_context;
|
||||
TF_ASSERT_OK(CreateTextLineDatasetContext(
|
||||
text_line_dataset_kernel.get(), &inputs, &text_line_dataset_context));
|
||||
|
||||
DatasetBase* text_line_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(text_line_dataset_kernel.get(),
|
||||
text_line_dataset_context.get(),
|
||||
&text_line_dataset));
|
||||
core::ScopedUnref scoped_unref(text_line_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(text_line_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedTextLineDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -74,6 +74,8 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
|
@ -361,45 +361,6 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(tf_record_dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
TF_ASSERT_OK(CreateTestFiles(test_case));
|
||||
|
||||
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
|
||||
|
||||
int64 num_files = test_case.filenames.size();
|
||||
Tensor filenames =
|
||||
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
|
||||
Tensor compression_type = CreateTensor<string>(
|
||||
TensorShape({}), {ToString(test_case.compression_type)});
|
||||
Tensor buffer_size =
|
||||
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
|
||||
TensorValue(&compression_type),
|
||||
TensorValue(&buffer_size)};
|
||||
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
|
||||
TF_ASSERT_OK(CreateTFRecordDatasetContext(
|
||||
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
|
||||
|
||||
DatasetBase* tf_record_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
|
||||
tf_record_dataset_context.get(),
|
||||
&tf_record_dataset));
|
||||
core::ScopedUnref scoped_unref(tf_record_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(tf_record_dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -59,6 +59,8 @@ class WindowDataset : public DatasetBase {
|
||||
|
||||
string DebugString() const override { return kWindowDataset; }
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
// TODO(b/110981596): Support checkpointing.
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -99,7 +99,9 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
||||
return cardinality;
|
||||
}
|
||||
|
||||
bool IsStateful() const override { return input_->IsStateful(); }
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
|
@ -574,49 +574,6 @@ TEST_P(ParameterizedWindowDatasetOpTest, Cardinality) {
|
||||
EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedWindowDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> window_dataset_kernel;
|
||||
TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes,
|
||||
test_case.expected_output_shapes,
|
||||
&window_dataset_kernel));
|
||||
|
||||
DatasetBase* range_dataset;
|
||||
TF_ASSERT_OK(CreateRangeDataset<int64>(
|
||||
test_case.range_data_param.start, test_case.range_data_param.end,
|
||||
test_case.range_data_param.step, "range", &range_dataset));
|
||||
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
|
||||
TF_ASSERT_OK(
|
||||
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
|
||||
Tensor size = test_case.size;
|
||||
Tensor shift = test_case.shift;
|
||||
Tensor stride = test_case.stride;
|
||||
Tensor drop_remainder = test_case.drop_remainder;
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&range_dataset_tensor), TensorValue(&size),
|
||||
TensorValue(&shift), TensorValue(&stride),
|
||||
TensorValue(&drop_remainder)});
|
||||
|
||||
std::unique_ptr<OpKernelContext> window_dataset_op_ctx;
|
||||
TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs,
|
||||
&window_dataset_op_ctx));
|
||||
DatasetBase* dataset;
|
||||
TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(),
|
||||
window_dataset_op_ctx.get(), &dataset));
|
||||
core::ScopedUnref scoped_unref_dataset(dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_context;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedWindowDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
|
@ -87,6 +87,13 @@ class ZipDatasetOp::Dataset : public DatasetBase {
|
||||
return result;
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& input : inputs_) {
|
||||
TF_RETURN_IF_ERROR(input->CheckExternalState());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
@ -103,15 +110,6 @@ class ZipDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool IsStateful() const override {
|
||||
for (const auto& input : inputs_) {
|
||||
if (input->IsStateful()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
class Iterator : public DatasetIterator<Dataset> {
|
||||
public:
|
||||
|
@ -333,41 +333,6 @@ TEST_P(ParameterizedZipDatasetOpTest, Cardinality) {
|
||||
test_case.expected_outputs.size() / num_tensors_per_slice);
|
||||
}
|
||||
|
||||
TEST_F(ZipDatasetOpTest, DatasetSave) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
|
||||
|
||||
const TestParam &test_case = TestCase1();
|
||||
std::vector<Tensor> range_dataset_tensors;
|
||||
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
|
||||
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
|
||||
&range_dataset_tensors));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
inputs.reserve(range_dataset_tensors.size());
|
||||
for (auto &tensor : range_dataset_tensors) {
|
||||
inputs.emplace_back(&tensor);
|
||||
}
|
||||
std::unique_ptr<OpKernel> dataset_kernel;
|
||||
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
|
||||
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
|
||||
inputs.size(), &dataset_kernel));
|
||||
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
|
||||
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
|
||||
&dataset_kernel_ctx));
|
||||
DatasetBase *zip_dataset;
|
||||
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
|
||||
&zip_dataset));
|
||||
core::ScopedUnref scoped_unref(zip_dataset);
|
||||
|
||||
std::unique_ptr<SerializationContext> serialization_ctx;
|
||||
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
|
||||
VariantTensorData data;
|
||||
VariantTensorDataWriter writer(&data);
|
||||
TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer));
|
||||
TF_ASSERT_OK(writer.Flush());
|
||||
}
|
||||
|
||||
TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.platform import test
|
||||
class ReplicateClusterTest(test_base.DatasetTestBase):
|
||||
|
||||
def setUp(self):
|
||||
super(ReplicateClusterTest, self).setUp()
|
||||
# Start the local server.
|
||||
worker_config = config_pb2.ConfigProto()
|
||||
worker_config.device_count["CPU"] = 2
|
||||
@ -99,7 +100,7 @@ class ReplicateClusterTest(test_base.DatasetTestBase):
|
||||
it1 = dataset_ops.make_initializable_iterator(dataset1)
|
||||
# We don't support stateful ops in functions as of now.
|
||||
with session.Session(self._target) as sess:
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
sess.run(it1.initializer)
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user