[tf.data] Serialization and checkpointing related cleanup.

This CL:
- removes unused `DatasetBase::Save()` and related tests
- replaces `SerilizationContext::optimization_only` with multiple functionality specific flags (`check_external_state`, `fail_if_unimplemented`, and `serialize_data_tensors`)
- introduces `DatasetBase::CheckExternalState` as an error-raising replacement for `DatasetBase::IsStateful` to make it possible to communicate the reason for why serialization failed through the error status
- adds `IteratorBase::SaveInternal` and `IteratorBase::RestoreInternal` in preparation of making these methods pure virtual

PiperOrigin-RevId: 262235093
This commit is contained in:
Jiri Simsa 2019-08-07 16:04:47 -07:00 committed by TensorFlower Gardener
parent 43a408b8ac
commit 6d8f05acd7
102 changed files with 405 additions and 1271 deletions

View File

@ -115,6 +115,15 @@ class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
const ::google::cloud::bigtable::Row& row,
std::vector<Tensor>* out_tensors) = 0;
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented("RestoreInternal is currently not supported");
}
private:
Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (reader_) {

View File

@ -97,7 +97,10 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
return "BigtableLookupDatasetOp::Dataset";
}
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -174,6 +177,17 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}
private:
Status ParseRow(IteratorContext* ctx,
const ::google::cloud::bigtable::Row& row,

View File

@ -71,7 +71,10 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -76,7 +76,10 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -89,7 +89,10 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
return "BigtableSampleKeyPairsDatasetOp::Dataset";
}
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -187,6 +190,17 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}
private:
mutex mu_;
size_t index_ GUARDED_BY(mu_) = 0;

View File

@ -64,7 +64,10 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -109,6 +112,17 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
return Status::OK();
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented("SaveInternal is currently not supported");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"RestoreInternal is currently not supported");
}
private:
mutex mu_;
size_t index_ = 0;

View File

@ -131,7 +131,10 @@ class BigtableScanDatasetOp : public DatasetOpKernel {
BigtableTableResource* table() const { return table_; }
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on external state.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -264,9 +264,6 @@ Status GraphDefBuilderWrapper::AddFunction(
<< " the graph. It will not be added again.";
return Status::OK();
}
if (!ctx->optimization_only()) {
TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(function_name, lib_def));
}
const FunctionDef* f_def = lib_def.Find(function_name);
if (f_def == nullptr) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
@ -369,29 +366,10 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) {
return Status::OK();
}
Status DatasetBase::Save(SerializationContext* ctx,
IteratorStateWriter* writer) const {
string serialized_graph_def;
string output_node;
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* node = nullptr;
TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
output_node = node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
graph_def.SerializeToString(&serialized_graph_def);
TF_RETURN_IF_ERROR(
writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
TF_RETURN_IF_ERROR(
writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
return Status::OK();
}
Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset(
SerializationContext* ctx, const DatasetBase* dataset, Node** output) {
Status status = dataset->AsGraphDefInternal(ctx, this, output);
if (ctx->optimization_only() && errors::IsUnimplemented(status)) {
if (errors::IsUnimplemented(status) && !ctx->fail_if_unimplemented()) {
Tensor t(DT_VARIANT, TensorShape({}));
// `StoreDatasetInVariantTensor` will transfer ownership of `dataset`. We
// increment the refcount of `dataset` here to retain ownership.

View File

@ -201,48 +201,6 @@ class GraphDefBuilderWrapper {
private:
void AddPlaceholderInternal(const Tensor& val, Node** output);
void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(
const string& function_name,
const FunctionLibraryDefinition& lib_def) const {
const FunctionDef* function_def = lib_def.Find(function_name);
if (!function_def) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in registry.");
}
for (const NodeDef& node_def : function_def->node_def()) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(lib_def.LookUpOpDef(node_def.op(), &op_def));
// TODO(b/65524810): Hack to allow functions to capture Dataset op
// nodes needed for FlatMap. Currently, source datasets nodes have been
// marked stateful to avoid constant folding since we do not have a
// good way of serializing them.
if (IsOpWhitelisted(op_def)) {
continue;
}
if (op_def->is_stateful()) {
return errors::InvalidArgument(
"Op[name: ", node_def.name(), ", type: ", node_def.op(), "] ",
"in function ", function_name, " is stateful. ",
"Saving stateful functions is not supported yet.");
}
}
return Status::OK();
}
// Returns whether an op has been whitelisted for use inside map_fns.
// Uses a heuristic to whitelist source dataset ops which have been
// marked stateful due to b/65524810.
// Also looks up the `op_def->name` in the global
// `WhitelistedStatefulOpRegistry`.
bool IsOpWhitelisted(const OpDef* op_def) const {
return ((absl::EndsWith(op_def->name(), "Dataset") ||
absl::EndsWith(op_def->name(), "DatasetV2")) &&
op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT) ||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
bool HasAttr(const string& op_type_name, const string& attr_name) const;
bool HasAttr(const OpDef* op_def, const string& attr_name) const {
@ -466,7 +424,23 @@ class SerializationContext {
public:
struct Params {
std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
bool optimization_only = false;
// Indicates whether serialization should check if the dataset depends on
// external state. If the check is enabled and external state is
// encountered, then the serialization will fail.
bool check_external_state = true;
// Indicates whether an attempt to serialize a dataset that does not
// implement serialization should result in an error. If set to `false`, the
// serialized graph will replace the dataset with a placeholder returned in
// `input_list`.
bool fail_if_unimplemented = true;
// Indicates whether (potentionally large) data tensors should be
// serialized, or replaced with a placeholder returned in `input_list`. The
// latter makes sense to do when performing data agnostic graph rewrites to
// reduce the memory usage.
bool serialize_data_tensors = true;
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
@ -475,7 +449,11 @@ class SerializationContext {
return params_.input_list;
}
bool optimization_only() { return params_.optimization_only; }
bool check_external_state() const { return params_.check_external_state; }
bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }
bool serialize_data_tensors() const { return params_.serialize_data_tensors; }
private:
Params params_;
@ -550,10 +528,6 @@ class IteratorBase {
return RestoreInternal(ctx, reader);
}
Status Restore(IteratorContext&& ctx, IteratorStateReader* reader) {
return Restore(&ctx, reader);
}
protected:
// Returns a node that models this iterator.
virtual std::shared_ptr<model::Node> CreateNode(
@ -573,12 +547,22 @@ class IteratorBase {
return input->RestoreInternal(ctx, reader);
}
// Saves the state of this iterator recursively.
// Saves the state of this iterator.
//
// This method is used to store the state of the iterator in a checkpoint.
//
// TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
// implementations have an override.
virtual Status SaveInternal(IteratorStateWriter* writer) {
return errors::Unimplemented("SaveInternal");
}
// Restores the state of this iterator recursively.
// Restores the state of this iterator.
//
// This method is used to restore the state of the iterator from a checkpoint.
//
// TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
// implementations have an override.
virtual Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
@ -718,17 +702,25 @@ class DatasetBase : public core::RefCounted {
// A human-readable debug string for this dataset.
virtual string DebugString() const = 0;
// Serializes the dataset and writes it to the `writer`.
virtual Status Save(SerializationContext* ctx,
IteratorStateWriter* writer) const;
// If the dataset is stateful it will not be possible to save its graph or
// checkpoint the state of its iterators.
//
// TODO(jsimsa): Remove this method once all `DatasetBase` implementations are
// migrated over to `CheckExternalState`.
virtual bool IsStateful() const { return false; }
// Indicates whether the dataset depends on external mutable state case in
// which case the serialization of the input pipeline graph and the
// checkpointing of the input pipeline state will not be supported.
// Indicates whether the dataset depends on any external state. If so, the
// method returns `errors::FailedPrecondition` with a message that identifies
// the external state. Otherwise, the method returns `Status::OK()`.
//
// TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
// implementations have an override.
virtual bool IsStateful() const { return false; }
virtual Status CheckExternalState() const {
if (IsStateful()) {
return errors::FailedPrecondition("Dataset cannot be serialized.");
}
return Status::OK();
}
protected:
friend Status AsGraphDef(
@ -739,11 +731,22 @@ class DatasetBase : public core::RefCounted {
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
explicit DatasetGraphDefBuilder(GraphDefBuilder* b)
: GraphDefBuilderWrapper(b) {}
Status AddInputDataset(SerializationContext* ctx,
const DatasetBase* dataset, Node** output);
};
// Serializes the dataset into a `GraphDef`, which has two uses:
//
// 1) To perform static input pipeline optimizations, tf.data serializes the
// dataset graph, applies graph rewrites, and then deserializes the graph.
// If a subclass of `DatasetBase` does not implement this method, then it will
// be excluded from static optimizations (and so will any upstream datasets).
//
// 2) To save the dataset so that it can restore at a later point (possibly in
// different environment). If a subclass of `DatasetBase` does not implement
// this method, then this migration will not be possible.
virtual Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** node) const = 0;
@ -802,10 +805,7 @@ class DatasetBaseIterator : public IteratorBase {
bool* end_of_sequence) final;
Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
if (params_.dataset->IsStateful()) {
return errors::FailedPrecondition(
"Saving iterator that depends on external state is not supported.");
}
TF_RETURN_IF_ERROR(params_.dataset->CheckExternalState());
return IteratorBase::Save(ctx, writer);
}

View File

@ -101,7 +101,9 @@ class BatchDatasetOp::Dataset : public DatasetBase {
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -449,46 +449,6 @@ TEST_P(ParameterizedBatchDatasetOpTest, Cardinality) {
EXPECT_EQ(batch_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedBatchDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> batch_dataset_kernel;
TF_ASSERT_OK(CreateBatchDatasetOpKernel(
test_case.parallel_copy, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &batch_dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_dataset_param.start, test_case.range_dataset_param.end,
test_case.range_dataset_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor batch_size = test_case.batch_size;
Tensor drop_remainder = test_case.drop_remainder;
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&range_dataset_tensor),
TensorValue(&batch_size),
TensorValue(&drop_remainder)};
std::unique_ptr<OpKernelContext> batch_dataset_context;
TF_ASSERT_OK(CreateBatchDatasetContext(batch_dataset_kernel.get(), &inputs,
&batch_dataset_context));
DatasetBase* batch_dataset;
TF_ASSERT_OK(CreateDataset(batch_dataset_kernel.get(),
batch_dataset_context.get(), &batch_dataset));
core::ScopedUnref scoped_unref_batch_dataset(batch_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(batch_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedBatchDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -101,7 +101,9 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -669,7 +671,9 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -971,7 +975,10 @@ class CacheDatasetOp::MemoryDatasetV2 : public CacheDatasetOp::MemoryDataset {
MemoryCache* cache, const Tensor& resource_handle)
: MemoryDataset(ctx, input, cache), resource_handle_(resource_handle) {}
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(DebugString(),
" depends on memory cache resource.");
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -353,39 +353,6 @@ TEST_P(ParameterizedCacheDatasetOpTest, Cardinality) {
EXPECT_EQ(cache_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedCacheDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> cache_dataset_kernel;
TF_ASSERT_OK(CreateCacheDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&cache_dataset_kernel));
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor file_name = CreateTensor<string>(TensorShape{}, {test_case.file_name});
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&file_name)});
std::unique_ptr<OpKernelContext> cache_dataset_context;
TF_ASSERT_OK(CreateCacheDatasetContext(cache_dataset_kernel.get(), &inputs,
&cache_dataset_context));
DatasetBase* cache_dataset;
TF_ASSERT_OK(CreateDataset(cache_dataset_kernel.get(),
cache_dataset_context.get(), &cache_dataset));
core::ScopedUnref scoped_unref(cache_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(cache_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedCacheDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -210,17 +210,19 @@ Status CreateFunctionLibraryDefinition(
return (*result)->CopyFunctionDefFrom(func_name, *lib_def);
}
bool IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node);
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node);
bool IsFunctionStateful(const FunctionLibraryDefinition& library,
const FunctionDef& function_def) {
if (!function_def.signature().is_stateful()) return false;
Status IsFunctionStateful(const FunctionLibraryDefinition& library,
const FunctionDef& function_def) {
if (!function_def.signature().is_stateful()) {
return Status::OK();
}
for (const NodeDef& node_def : function_def.node_def()) {
if (IsNodeStateful(library, node_def)) return true;
TF_RETURN_IF_ERROR(IsNodeStateful(library, node_def));
}
return false;
return Status::OK();
}
// Returns whether an op has been whitelisted as stateless. Uses a heuristic to
@ -228,27 +230,23 @@ bool IsFunctionStateful(const FunctionLibraryDefinition& library,
// b/65524810. Also looks up the `op_def->name` in the global
// `WhitelistedStatefulOpRegistry`.
bool IsOpWhitelisted(const OpDef* op_def) {
return ((absl::EndsWith(op_def->name(), "Dataset") ||
absl::EndsWith(op_def->name(), "DatasetV2")) &&
op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT) ||
return (op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT &&
(absl::EndsWith(op_def->name(), "Dataset") ||
absl::EndsWith(op_def->name(), "DatasetV2"))) ||
WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
bool IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node) {
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node) {
const OpDef* op_def;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
if (!s.ok()) {
return false;
}
if (IsOpWhitelisted(op_def)) return false;
if (!op_def->is_stateful()) return false;
if (op_def->name() == "Assert") {
return false;
// TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore
// `LookUpOpDef` errors here.
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() ||
IsOpWhitelisted(op_def) || !op_def->is_stateful() ||
op_def->name() == "Assert") {
return Status::OK();
}
if (op_def->name() == "If") {
@ -256,10 +254,13 @@ bool IsNodeStateful(const FunctionLibraryDefinition& library,
library.Find(node.attr().at("then_branch").func().name());
const FunctionDef* else_func =
library.Find(node.attr().at("else_branch").func().name());
if ((then_func != nullptr && !IsFunctionStateful(library, *then_func)) &&
(else_func != nullptr && !IsFunctionStateful(library, *else_func))) {
return false;
if (then_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *then_func));
}
if (else_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *else_func));
}
return Status::OK();
}
if (op_def->name() == "While") {
@ -267,12 +268,16 @@ bool IsNodeStateful(const FunctionLibraryDefinition& library,
library.Find(node.attr().at("cond").func().name());
const FunctionDef* body_func =
library.Find(node.attr().at("body").func().name());
if ((cond_func != nullptr && !IsFunctionStateful(library, *cond_func)) &&
(body_func != nullptr && !IsFunctionStateful(library, *body_func))) {
return false;
if (cond_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *cond_func));
}
if (body_func != nullptr) {
TF_RETURN_IF_ERROR(IsFunctionStateful(library, *body_func));
}
return Status::OK();
}
return true;
return errors::FailedPrecondition(op_def->name(), " is stateful.");
}
} // namespace
@ -482,13 +487,14 @@ Status CapturedFunction::Instantiate(
return Status::OK();
}
bool CapturedFunction::IsStateful() const {
bool CapturedFunction::IsStateful() const { return !CheckExternalState().ok(); }
Status CapturedFunction::CheckExternalState() const {
for (const auto& name : lib_def()->ListFunctionNames()) {
if (IsFunctionStateful(*lib_def(), *(lib_def()->Find(name)))) {
return true;
}
TF_RETURN_IF_ERROR(
IsFunctionStateful(*lib_def(), *(lib_def()->Find(name))));
}
return false;
return Status::OK();
}
namespace {

View File

@ -205,8 +205,14 @@ class CapturedFunction {
instantiated_captured_function);
// Determines whether the captured function is stateful.
//
// TODO(jsimsa): Remove this method once all users of `CapturedFunction`
// migrate to `CheckExternalState`.
bool IsStateful() const;
// Determines whether the captured function is stateful.
Status CheckExternalState() const;
// Returns the additional captured inputs that will be passed to the function.
const std::vector<Tensor>& captured_inputs() const {
return captured_inputs_;

View File

@ -85,8 +85,9 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
return n1 + n2;
}
bool IsStateful() const override {
return input_->IsStateful() || to_concatenate_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(input_->CheckExternalState());
return to_concatenate_->CheckExternalState();
}
protected:

View File

@ -357,39 +357,6 @@ TEST_P(ParameterizedConcatenateDatasetOpTest, Cardinality) {
EXPECT_EQ(concatenate_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_F(ConcatenateDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = SameShapeTestCase();
std::vector<Tensor> tensor_slice_dataset_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors,
&tensor_slice_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &tensor : tensor_slice_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *concatenate_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&concatenate_dataset));
core::ScopedUnref scoped_unref(concatenate_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(concatenate_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -593,16 +593,6 @@ Status DatasetOpsTestBase::CheckDatasetCardinality(const DatasetBase& dataset,
return Status::OK();
}
Status DatasetOpsTestBase::CheckDatasetSave(const DatasetBase& dataset) {
std::unique_ptr<SerializationContext> serialization_context;
TF_RETURN_IF_ERROR(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(dataset.Save(serialization_context.get(), &writer));
TF_RETURN_IF_ERROR(writer.Flush());
return Status::OK();
}
Status DatasetOpsTestBase::CheckDatasetIsStateful(const DatasetBase& dataset,
bool expected_stateful) {
EXPECT_EQ(dataset.IsStateful(), expected_stateful);

View File

@ -317,9 +317,6 @@ class DatasetOpsTestBase : public ::testing::Test {
Status CheckDatasetCardinality(const DatasetBase& dataset,
int64 expected_cardinality);
// Checks `DatasetBase::Save()`.
Status CheckDatasetSave(const DatasetBase& dataset);
// Checks `DatasetBase::IsStateful()`.
Status CheckDatasetIsStateful(const DatasetBase& dataset,
bool expected_stateful);

View File

@ -375,13 +375,16 @@ uint64 HashSubgraphFunctionImpl(
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def) {
if (serialization_ctx.check_external_state()) {
TF_RETURN_IF_ERROR(dataset->CheckExternalState());
}
GraphDefBuilder b;
DatasetBase::DatasetGraphDefBuilder db(&b);
Node* output_node = nullptr;
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, dataset, &output_node));
// Insert a purely symbolic _Retval node to indicate to consumers which Tensor
// represents this Dataset.
// Insert a purely symbolic _Retval node to indicate to consumers which node
// represents `dataset`.
ops::UnaryOp("_Retval", output_node,
b.opts()
.WithName("dataset")
@ -415,7 +418,9 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.optimization_only = true;
params.check_external_state = false;
params.fail_if_unimplemented = false;
params.serialize_data_tensors = false;
SerializationContext serialization_ctx(params);
GraphDef graph_def;
TF_RETURN_IF_ERROR(

View File

@ -63,7 +63,9 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -376,43 +376,6 @@ TEST_P(ParameterizedAssertNextDatasetOpTest, Cardinality) {
EXPECT_EQ(assert_next_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedAssertNextDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor range_and_take_dataset_tensor;
TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
test_case.take_dataset_params,
&range_and_take_dataset_tensor));
std::unique_ptr<OpKernel> assert_next_dataset_kernel;
TF_ASSERT_OK(CreateAssertNextDatasetOpKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&assert_next_dataset_kernel));
Tensor transformations = test_case.transformations;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_and_take_dataset_tensor),
TensorValue(&transformations)});
std::unique_ptr<OpKernelContext> assert_next_dataset_context;
TF_ASSERT_OK(CreateAssertNextDatasetContext(
assert_next_dataset_kernel.get(), &inputs, &assert_next_dataset_context));
DatasetBase* assert_next_dataset;
TF_ASSERT_OK(CreateDataset(assert_next_dataset_kernel.get(),
assert_next_dataset_context.get(),
&assert_next_dataset));
core::ScopedUnref scoped_unref(assert_next_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(assert_next_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedAssertNextDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -242,13 +242,11 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
}
bool IsStateful() const override {
Status CheckExternalState() const override {
for (const auto& captured_func : captured_funcs_) {
if (captured_func->IsStateful()) {
return true;
}
TF_RETURN_IF_ERROR(captured_func->CheckExternalState());
}
return input_->IsStateful();
return input_->CheckExternalState();
}
protected:

View File

@ -158,13 +158,11 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
int64 Cardinality() const override { return cardinality_; }
bool IsStateful() const override {
Status CheckExternalState() const override {
for (const auto& input : inputs_) {
if (input->IsStateful()) {
return true;
}
TF_RETURN_IF_ERROR(input->CheckExternalState());
}
return false;
return Status::OK();
}
protected:

View File

@ -170,6 +170,8 @@ class CSVDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "CSVDatasetOp::Dataset"; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -120,7 +120,9 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -107,13 +107,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
}
bool IsStateful() const override {
Status CheckExternalState() const override {
for (const auto& input : data_inputs_) {
if (input->IsStateful()) {
return true;
}
TF_RETURN_IF_ERROR(input->CheckExternalState());
}
return selector_input_->IsStateful();
return selector_input_->CheckExternalState();
}
protected:

View File

@ -112,11 +112,12 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
return "GroupByReducerDatasetOp::Dataset";
}
bool IsStateful() const override {
return captured_key_func_->IsStateful() ||
captured_init_func_->IsStateful() ||
captured_reduce_func_->IsStateful() ||
captured_finalize_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
TF_RETURN_IF_ERROR(captured_init_func_->CheckExternalState());
TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
TF_RETURN_IF_ERROR(captured_finalize_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -108,10 +108,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
return "GroupByWindowDatasetOp::Dataset";
}
bool IsStateful() const override {
return captured_key_func_->IsStateful() ||
captured_reduce_func_->IsStateful() ||
captured_window_size_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
TF_RETURN_IF_ERROR(captured_window_size_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -60,7 +60,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -70,6 +70,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -143,8 +143,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
(n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
}
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -81,6 +81,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
return "MatchingFilesDatasetOp::Dataset";
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -69,13 +69,16 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
return "NonSerializableDatasetOp::Dataset";
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
return errors::Unimplemented(DebugString(),
" does not support serialization.");
}
int64 Cardinality() const override { return input_->Cardinality(); }

View File

@ -121,8 +121,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -569,46 +569,6 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Cardinality) {
test_case.expected_cardinality);
}
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &parallel_interleave_dataset_kernel));
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&tensor_slice_dataset_tensor),
TensorValue(&test_case.cycle_length),
TensorValue(&test_case.block_length), TensorValue(&test_case.sloppy),
TensorValue(&test_case.buffer_output_elements),
TensorValue(&test_case.prefetch_input_elements)});
std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
TF_ASSERT_OK(CreateParallelInterleaveDatasetContext(
parallel_interleave_dataset_kernel.get(), &inputs,
&parallel_interleave_dataset_context));
DatasetBase* parallel_interleave_dataset;
TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
parallel_interleave_dataset_context.get(),
&parallel_interleave_dataset));
core::ScopedUnref scoped_unref_dataset(parallel_interleave_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(
parallel_interleave_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -206,7 +206,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -76,6 +76,8 @@ class RandomDatasetOp : public DatasetOpKernel {
int64 Cardinality() const override { return kInfiniteCardinality; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -79,7 +79,9 @@ class SamplingDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "SamplingDatasetOp::Dataset"; }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -102,8 +102,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -138,7 +138,9 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -67,7 +67,9 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -108,7 +108,9 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
return n / window_shift_;
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -314,7 +314,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.optimization_only = true;
params.check_external_state = false;
GraphDef graph_def;
OP_REQUIRES_OK(
@ -376,7 +376,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -103,6 +103,8 @@ class SqlDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "SqlDatasetOp::Dataset"; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -78,6 +78,10 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
@ -189,6 +193,10 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -84,8 +84,9 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return kUnknownCardinality; }
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -172,7 +172,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -281,7 +283,9 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
@ -383,7 +387,9 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -69,7 +69,9 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "UnbatchDatasetOp::Dataset"; }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -72,7 +72,9 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
return strings::StrCat("UniqueDatasetOp::Dataset");
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -75,8 +75,9 @@ class FilterDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -346,39 +346,6 @@ TEST_P(ParameterizedFilterDatasetOpTest, Cardinality) {
EXPECT_EQ(filter_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedFilterDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> filter_dataset_kernel;
TF_ASSERT_OK(CreateFilterDatasetKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &filter_dataset_kernel));
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&tensor_slice_dataset_tensor)});
std::unique_ptr<OpKernelContext> filter_dataset_context;
TF_ASSERT_OK(CreateFilterDatasetContext(filter_dataset_kernel.get(), &inputs,
&filter_dataset_context));
DatasetBase *filter_dataset;
TF_ASSERT_OK(CreateDataset(filter_dataset_kernel.get(),
filter_dataset_context.get(), &filter_dataset));
core::ScopedUnref scoped_unref(filter_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(filter_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedFilterDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();

View File

@ -93,6 +93,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType, params);
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -452,56 +452,6 @@ TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, Cardinality) {
test_case.expected_cardinality);
}
TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> fixed_length_record_dataset_kernel;
TF_ASSERT_OK(CreateFixedLengthRecordDatasetOpKernel(
&fixed_length_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor header_bytes =
CreateTensor<int64>(TensorShape({}), {test_case.header_bytes});
Tensor record_bytes =
CreateTensor<int64>(TensorShape({}), {test_case.record_bytes});
Tensor footer_bytes =
CreateTensor<int64>(TensorShape({}), {test_case.footer_bytes});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
gtl::InlinedVector<TensorValue, 4> inputs{
TensorValue(&filenames), TensorValue(&header_bytes),
TensorValue(&record_bytes), TensorValue(&footer_bytes),
TensorValue(&buffer_size), TensorValue(&compression_type),
};
std::unique_ptr<OpKernelContext> fixed_length_record_dataset_context;
TF_ASSERT_OK(CreateFixedLengthRecordDatasetContext(
fixed_length_record_dataset_kernel.get(), &inputs,
&fixed_length_record_dataset_context));
DatasetBase* fixed_length_record_dataset;
TF_ASSERT_OK(CreateDataset(fixed_length_record_dataset_kernel.get(),
fixed_length_record_dataset_context.get(),
&fixed_length_record_dataset));
core::ScopedUnref scoped_unref(fixed_length_record_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(
fixed_length_record_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedFixedLengthRecordDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -75,8 +75,9 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -364,41 +364,6 @@ TEST_P(ParameterizedFlatMapDatasetOpTest, Cardinality) {
EXPECT_EQ(flat_map_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_F(FlatMapDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = MakeTensorSliceDatasetFuncTestCase();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> flat_map_dataset_kernel;
TF_ASSERT_OK(CreateFlatMapDatasetKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &flat_map_dataset_kernel));
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&tensor_slice_dataset_tensor)});
std::unique_ptr<OpKernelContext> flat_map_dataset_context;
TF_ASSERT_OK(CreateFlatMapDatasetContext(flat_map_dataset_kernel.get(),
&inputs, &flat_map_dataset_context));
DatasetBase *flat_map_dataset;
TF_ASSERT_OK(CreateDataset(flat_map_dataset_kernel.get(),
flat_map_dataset_context.get(),
&flat_map_dataset));
core::ScopedUnref scoped_unref(flat_map_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(flat_map_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedFlatMapDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();

View File

@ -74,17 +74,18 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
bool IsStateful() const override {
return init_func_->IsStateful() || next_func_->IsStateful() ||
finalize_func_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
return finalize_func_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
return errors::Unimplemented("%s does not support serialization",
DebugString());
return errors::Unimplemented(DebugString(),
" does not support serialization");
}
private:

View File

@ -81,8 +81,9 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -570,43 +570,6 @@ TEST_P(ParameterizedInterleaveDatasetOpTest, Cardinality) {
EXPECT_EQ(interleave_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedInterleaveDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> interleave_dataset_kernel;
TF_ASSERT_OK(CreateInterleaveDatasetKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &interleave_dataset_kernel));
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor cycle_length = test_case.cycle_length;
Tensor block_length = test_case.block_length;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&cycle_length),
TensorValue(&block_length)});
std::unique_ptr<OpKernelContext> interleave_dataset_context;
TF_ASSERT_OK(CreateInterleaveDatasetContext(
interleave_dataset_kernel.get(), &inputs, &interleave_dataset_context));
DatasetBase *interleave_dataset;
TF_ASSERT_OK(CreateDataset(interleave_dataset_kernel.get(),
interleave_dataset_context.get(),
&interleave_dataset));
core::ScopedUnref scoped_unref(interleave_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(interleave_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedInterleaveDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();

View File

@ -126,8 +126,8 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
params.cancellation_manager,
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
return captured_state->iterator->Restore(IteratorContext(std::move(params)),
reader);
IteratorContext iter_ctx(std::move(params));
return captured_state->iterator->Restore(&iter_ctx, reader);
}
return errors::FailedPrecondition(
"Restore() failed because the iterator has not been initialized. Ensure "

View File

@ -74,8 +74,9 @@ class MapDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -341,43 +341,6 @@ TEST_P(ParameterizedMapDatasetOpTest, Cardinality) {
EXPECT_EQ(map_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedMapDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.start, test_case.end, test_case.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
// The ownership of range_dataset is transferred to DatasetVariantWrapper,
// which will handle the release of memory.
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
gtl::InlinedVector<TensorValue, 4> map_dataset_inputs;
map_dataset_inputs.emplace_back(&range_dataset_tensor);
std::unique_ptr<OpKernel> map_dataset_kernel;
TF_ASSERT_OK(CreateMapDatasetOpKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &map_dataset_kernel));
std::unique_ptr<OpKernelContext> map_dataset_context;
TF_ASSERT_OK(CreateMapDatasetContext(
map_dataset_kernel.get(), &map_dataset_inputs, &map_dataset_context));
DatasetBase* map_dataset;
TF_ASSERT_OK(CreateDataset(map_dataset_kernel.get(),
map_dataset_context.get(), &map_dataset));
core::ScopedUnref scoped_unref_map_dataset(map_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(map_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = TestCase1();

View File

@ -85,7 +85,9 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -114,7 +114,9 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -785,53 +785,6 @@ TEST_P(ParameterizedPaddedBatchDatasetOpTest, Cardinality) {
test_case.expected_cardinality);
}
TEST_P(ParameterizedPaddedBatchDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> padded_batch_dataset_kernel;
TF_ASSERT_OK(CreatePaddedBatchDatasetKernel(
test_case.parallel_copy, test_case.n, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &padded_batch_dataset_kernel));
Tensor concatenate_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(CreateConcatenateDatasetTensor(
test_case.input_tensors, test_case.concatenate_output_dtypes,
test_case.concatenate_output_shapes, &concatenate_dataset_tensor));
Tensor batch_size = test_case.batch_size;
std::vector<Tensor> padded_shapes = test_case.padded_shapes;
std::vector<Tensor> padding_values = test_case.padding_values;
Tensor drop_remainder = test_case.drop_remainder;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&concatenate_dataset_tensor), TensorValue(&batch_size)});
for (auto &padded_shape : padded_shapes) {
inputs.emplace_back(&padded_shape);
}
for (auto &padding_value : padding_values) {
inputs.emplace_back(&padding_value);
}
inputs.emplace_back(&drop_remainder);
std::unique_ptr<OpKernelContext> padded_batch_dataset_context;
TF_ASSERT_OK(
CreatePaddedBatchDatasetContext(padded_batch_dataset_kernel.get(),
&inputs, &padded_batch_dataset_context));
DatasetBase *padded_batch_dataset;
TF_ASSERT_OK(CreateDataset(padded_batch_dataset_kernel.get(),
padded_batch_dataset_context.get(),
&padded_batch_dataset));
core::ScopedUnref scoped_unref(padded_batch_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(padded_batch_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedPaddedBatchDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();

View File

@ -153,6 +153,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
ParallelInterleaveDatasetOp::kDatasetType, params);
}
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -724,47 +724,6 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Cardinality) {
test_case.expected_cardinality);
}
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> parallel_interleave_dataset_kernel;
TF_ASSERT_OK(CreateParallelInterleaveDatasetKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, test_case.sloppy,
&parallel_interleave_dataset_kernel));
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor cycle_length = test_case.cycle_length;
Tensor block_length = test_case.block_length;
Tensor num_parallel_calls = test_case.num_parallel_calls;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&cycle_length),
TensorValue(&block_length), TensorValue(&num_parallel_calls)});
std::unique_ptr<OpKernelContext> parallel_interleave_dataset_context;
TF_ASSERT_OK(CreateInterleaveDatasetContext(
parallel_interleave_dataset_kernel.get(), &inputs,
&parallel_interleave_dataset_context));
DatasetBase *parallel_interleave_dataset;
TF_ASSERT_OK(CreateDataset(parallel_interleave_dataset_kernel.get(),
parallel_interleave_dataset_context.get(),
&parallel_interleave_dataset));
core::ScopedUnref scoped_unref(parallel_interleave_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(
parallel_interleave_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();

View File

@ -89,8 +89,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override {
return captured_func_->IsStateful() || input_->IsStateful();
Status CheckExternalState() const override {
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
return input_->CheckExternalState();
}
protected:

View File

@ -527,49 +527,6 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Cardinality) {
test_case.expected_cardinality);
}
TEST_P(ParameterizedParallelMapDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
std::unique_ptr<OpKernel> parallel_map_dataset_kernel;
TF_ASSERT_OK(CreateParallelMapDatasetOpKernel(
test_case.func, test_case.expected_output_dtypes,
test_case.expected_output_shapes, test_case.use_inter_op_parallelism,
test_case.sloppy, test_case.preserve_cardinality,
&parallel_map_dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor num_parallel_calls = test_case.num_parallel_calls;
gtl::InlinedVector<TensorValue, 4> parallel_map_dataset_inputs(
{TensorValue(&range_dataset_tensor), TensorValue(&num_parallel_calls)});
std::unique_ptr<OpKernelContext> parallel_map_dataset_context;
TF_ASSERT_OK(CreateParallelMapDatasetContext(
parallel_map_dataset_kernel.get(), &parallel_map_dataset_inputs,
&parallel_map_dataset_context));
DatasetBase* parallel_map_dataset;
TF_ASSERT_OK(CreateDataset(parallel_map_dataset_kernel.get(),
parallel_map_dataset_context.get(),
&parallel_map_dataset));
core::ScopedUnref scoped_unref_map_dataset(parallel_map_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(
parallel_map_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedParallelMapDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -83,7 +83,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return input_->Cardinality(); }
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -398,43 +398,6 @@ TEST_P(ParameterizedPrefetchDatasetOpTest, Cardinality) {
EXPECT_EQ(prefetch_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_F(PrefetchDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PositiveBufferSizeTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor buffer_size =
CreateTensor<int64>(TensorShape{}, {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs_for_prefetch_dataset(
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&buffer_size)});
std::unique_ptr<OpKernel> prefetch_dataset_kernel;
TF_ASSERT_OK(CreatePrefetchDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&prefetch_dataset_kernel));
std::unique_ptr<OpKernelContext> prefetch_dataset_context;
TF_ASSERT_OK(CreatePrefetchDatasetContext(prefetch_dataset_kernel.get(),
&inputs_for_prefetch_dataset,
&prefetch_dataset_context));
DatasetBase *prefetch_dataset;
TF_ASSERT_OK(CreateDataset(prefetch_dataset_kernel.get(),
prefetch_dataset_context.get(),
&prefetch_dataset));
core::ScopedUnref scoped_unref(prefetch_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(prefetch_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_F(PrefetchDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -73,6 +73,8 @@ class RangeDatasetOp::Dataset : public DatasetBase {
}
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -299,28 +299,6 @@ DatasetSaveTestCase<RangeDatasetParams> DatasetSaveTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset()};
}
TEST_F(RangeDatasetOpTest, DatasetSave) {
int64 thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = DatasetSaveTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(CheckDatasetSave(*range_dataset));
}
IsStatefulTestCase<RangeDatasetParams> IsStatefulTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_stateful=*/false};

View File

@ -89,7 +89,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
return count_ * n;
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -349,41 +349,6 @@ TEST_P(ParameterizedDatasetOpTest, Cardinality) {
EXPECT_EQ(repeat_dataset->Cardinality(), GetParam().expected_cardinality);
}
TEST_F(RepeatDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = FiniteRepeatTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
gtl::InlinedVector<TensorValue, 4> inputs_for_repeat_dataset;
inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor);
inputs_for_repeat_dataset.emplace_back(&count);
std::unique_ptr<OpKernel> repeat_dataset_kernel;
TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&repeat_dataset_kernel));
std::unique_ptr<OpKernelContext> repeat_dataset_context;
TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(),
&inputs_for_repeat_dataset,
&repeat_dataset_context));
DatasetBase *repeat_dataset;
TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(),
repeat_dataset_context.get(), &repeat_dataset));
core::ScopedUnref scoped_unref(repeat_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(repeat_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -79,7 +79,9 @@ class ShardDatasetOp::Dataset : public DatasetBase {
return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -497,47 +497,6 @@ TEST_P(ParameterizedShardDatasetOpTest, Cardinality) {
EXPECT_EQ(shard_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedShardDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> shard_dataset_kernel;
TF_ASSERT_OK(CreateShardDatasetOpKernel(
test_case.require_non_empty, test_case.expected_output_dtypes,
test_case.expected_output_shapes, &shard_dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_dataset_param.start, test_case.range_dataset_param.end,
test_case.range_dataset_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor num_shards = test_case.num_shards;
Tensor index = test_case.index;
gtl::InlinedVector<TensorValue, 4> inputs({TensorValue(&range_dataset_tensor),
TensorValue(&num_shards),
TensorValue(&index)});
std::unique_ptr<OpKernelContext> shard_dataset_context;
TF_ASSERT_OK(CreateShardDatasetContext(shard_dataset_kernel.get(), &inputs,
&shard_dataset_context));
DatasetBase* shard_dataset;
TF_ASSERT_OK(CreateDataset(shard_dataset_kernel.get(),
shard_dataset_context.get(), &shard_dataset));
core::ScopedUnref scoped_unref_batch_dataset(shard_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(shard_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedShardDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -103,7 +103,9 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
}
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
template <class T>
@ -540,7 +542,10 @@ class ShuffleDatasetOp::ReshufflingDatasetV2 : public ShuffleDatasetBase {
return name_utils::DatasetDebugString(kDatasetType, params);
}
bool IsStateful() const override { return true; }
Status CheckExternalState() const override {
return errors::FailedPrecondition(
DebugString(), " depends on random seed generator resource.");
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {

View File

@ -606,51 +606,6 @@ TEST_P(ParameterizedShuffleDatasetOpTest, Cardinality) {
EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedShuffleDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
Tensor count = test_case.count;
int64 count_value = count.flat<int64>()(0);
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(
CreateDatasetOpKernel(count_value, test_case.reshuffle_each_iteration,
test_case.expected_output_dtypes,
test_case.expected_output_shapes, &dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor buffer_size = test_case.buffer_size;
Tensor seed = test_case.seed;
Tensor seed2 = test_case.seed2;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_dataset_tensor), TensorValue(&buffer_size),
TensorValue(&seed), TensorValue(&seed2)});
if (count_value != 1) inputs.push_back(TensorValue(&count));
std::unique_ptr<OpKernelContext> dataset_context;
TF_ASSERT_OK(
CreateDatasetContext(dataset_kernel.get(), &inputs, &dataset_context));
DatasetBase* dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_context.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedShuffleDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -75,7 +75,9 @@ class SkipDatasetOp::Dataset : public DatasetBase {
return count_ < 0 ? 0 : std::max(0LL, n - count_);
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -356,41 +356,6 @@ TEST_P(ParameterizedSkipDatasetOpTest, Cardinality) {
EXPECT_EQ(skip_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_F(SkipDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = SkipLessTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
gtl::InlinedVector<TensorValue, 4> inputs_for_skip_dataset(
{TensorValue(&tensor_slice_dataset_tensor), TensorValue(&count)});
std::unique_ptr<OpKernel> skip_dataset_kernel;
TF_ASSERT_OK(CreateSkipDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&skip_dataset_kernel));
std::unique_ptr<OpKernelContext> skip_dataset_context;
TF_ASSERT_OK(CreateSkipDatasetContext(skip_dataset_kernel.get(),
&inputs_for_skip_dataset,
&skip_dataset_context));
DatasetBase *skip_dataset;
TF_ASSERT_OK(CreateDataset(skip_dataset_kernel.get(),
skip_dataset_context.get(), &skip_dataset));
core::ScopedUnref scoped_unref(skip_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(skip_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedSkipDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -56,6 +56,8 @@ class Dataset : public DatasetBase {
int64 Cardinality() const override { return sparse_tensor_.shape()[0]; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -324,38 +324,6 @@ TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Cardinality) {
EXPECT_EQ(dataset->Cardinality(), expected_outputs.size());
}
TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = TwoDimsTestCase();
SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor;
std::vector<SparseTensorParam> expected_outputs = test_case.expected_outputs;
DataType tvalues = input_sparse_tensor.values.dtype();
gtl::InlinedVector<TensorValue, 4> inputs = {
TensorValue(&input_sparse_tensor.indices),
TensorValue(&input_sparse_tensor.values),
TensorValue(&input_sparse_tensor.dense_shape)};
std::unique_ptr<OpKernel> dataset_kernel;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext(
dataset_kernel.get(), &inputs, &dataset_kernel_ctx));
DatasetBase *dataset;
TF_ASSERT_OK(
CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset));
core::ScopedUnref scoped_unref(dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -71,7 +71,9 @@ int64 TakeDataset::Cardinality() const {
return std::min(n, count_);
}
bool TakeDataset::IsStateful() const { return input_->IsStateful(); }
Status TakeDataset::CheckExternalState() const {
return input_->CheckExternalState();
}
class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
public:

View File

@ -40,7 +40,7 @@ class TakeDataset : public DatasetBase {
int64 Cardinality() const override;
bool IsStateful() const override;
Status CheckExternalState() const override;
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -351,41 +351,6 @@ TEST_P(ParameterizedTakeDatasetOpTest, Cardinality) {
EXPECT_EQ(take_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_F(TakeDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = TakeLessTestCase();
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> inputs_for_tensor_slice_dataset = test_case.input_tensors;
TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset,
&tensor_slice_dataset_tensor));
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
gtl::InlinedVector<TensorValue, 4> inputs_for_take_dataset;
inputs_for_take_dataset.emplace_back(&tensor_slice_dataset_tensor);
inputs_for_take_dataset.emplace_back(&count);
std::unique_ptr<OpKernel> take_dataset_kernel;
TF_ASSERT_OK(CreateTakeDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&take_dataset_kernel));
std::unique_ptr<OpKernelContext> take_dataset_context;
TF_ASSERT_OK(CreateTakeDatasetContext(take_dataset_kernel.get(),
&inputs_for_take_dataset,
&take_dataset_context));
DatasetBase *take_dataset;
TF_ASSERT_OK(CreateDataset(take_dataset_kernel.get(),
take_dataset_context.get(), &take_dataset));
core::ScopedUnref scoped_unref(take_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(take_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedTakeDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -62,6 +62,8 @@ class TensorDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return 1LL; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
@ -70,12 +72,12 @@ class TensorDatasetOp::Dataset : public DatasetBase {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
if (ctx->optimization_only()) {
if (ctx->serialize_data_tensors()) {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
} else {
TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
DCHECK_NE(ctx->input_list(), nullptr);
ctx->input_list()->emplace_back(node->name(), t);
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
components.emplace_back(node);
}

View File

@ -305,37 +305,6 @@ TEST_F(TensorDatasetOpTest, Cardinality) {
EXPECT_EQ(tensor_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParametrizedTensorDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = GetParam();
std::vector<Tensor> components = test_case.components;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.push_back(TensorValue(&component));
}
std::unique_ptr<OpKernel> tensor_dataset_kernel;
TF_ASSERT_OK(CreateTensorDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&tensor_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_dataset_context;
TF_ASSERT_OK(CreateTensorDatasetContext(tensor_dataset_kernel.get(), &inputs,
&tensor_dataset_context));
DatasetBase *tensor_dataset;
TF_ASSERT_OK(CreateDataset(tensor_dataset_kernel.get(),
tensor_dataset_context.get(), &tensor_dataset));
core::ScopedUnref scoped_unref(tensor_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(tensor_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParametrizedTensorDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -68,6 +68,8 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return tensors_[0].dim_size(0); }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
@ -76,12 +78,12 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
if (ctx->optimization_only()) {
if (ctx->serialize_data_tensors()) {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
} else {
TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
DCHECK_NE(ctx->input_list(), nullptr);
ctx->input_list()->emplace_back(node->name(), t);
} else {
TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
}
components.emplace_back(node);
}

View File

@ -387,48 +387,6 @@ TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) {
EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0));
}
TEST_F(TensorSliceDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestCase &test_case = PlainTensorTestCase();
const std::vector<Tensor> &expected_outputs = test_case.expected_outputs;
std::vector<Tensor> components = test_case.components;
DataTypeVector dtypes;
gtl::InlinedVector<TensorValue, 4> inputs;
for (auto &component : components) {
inputs.emplace_back(&component);
dtypes.emplace_back(component.dtype());
}
size_t num_tensors_per_slice = components.size();
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_tensors_per_slice);
for (int i = 0; i < num_tensors_per_slice; ++i) {
shapes.emplace_back(expected_outputs[i].shape());
}
std::unique_ptr<OpKernel> tensor_slice_dataset_kernel;
TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes,
&tensor_slice_dataset_kernel));
std::unique_ptr<OpKernelContext> tensor_slice_dataset_context;
TF_ASSERT_OK(
CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(),
&inputs, &tensor_slice_dataset_context));
DatasetBase *tensor_slice_dataset;
TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(),
tensor_slice_dataset_context.get(),
&tensor_slice_dataset));
core::ScopedUnref scoped_unref(tensor_slice_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(
tensor_slice_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -70,6 +70,8 @@ class TextLineDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -366,45 +366,6 @@ TEST_P(ParameterizedTextLineDatasetOpTest, Cardinality) {
EXPECT_EQ(text_line_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedTextLineDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> text_line_dataset_kernel;
TF_ASSERT_OK(CreateTextLineDatasetOpKernel(&text_line_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> text_line_dataset_context;
TF_ASSERT_OK(CreateTextLineDatasetContext(
text_line_dataset_kernel.get(), &inputs, &text_line_dataset_context));
DatasetBase* text_line_dataset;
TF_ASSERT_OK(CreateDataset(text_line_dataset_kernel.get(),
text_line_dataset_context.get(),
&text_line_dataset));
core::ScopedUnref scoped_unref(text_line_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(text_line_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedTextLineDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -74,6 +74,8 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
return name_utils::DatasetDebugString(kDatasetType);
}
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -361,45 +361,6 @@ TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {
EXPECT_EQ(tf_record_dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateTestFiles(test_case));
std::unique_ptr<OpKernel> tf_record_dataset_kernel;
TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
int64 num_files = test_case.filenames.size();
Tensor filenames =
CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
Tensor compression_type = CreateTensor<string>(
TensorShape({}), {ToString(test_case.compression_type)});
Tensor buffer_size =
CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
TensorValue(&compression_type),
TensorValue(&buffer_size)};
std::unique_ptr<OpKernelContext> tf_record_dataset_context;
TF_ASSERT_OK(CreateTFRecordDatasetContext(
tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
DatasetBase* tf_record_dataset;
TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
tf_record_dataset_context.get(),
&tf_record_dataset));
core::ScopedUnref scoped_unref(tf_record_dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(tf_record_dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -59,6 +59,8 @@ class WindowDataset : public DatasetBase {
string DebugString() const override { return kWindowDataset; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
// TODO(b/110981596): Support checkpointing.
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -99,7 +99,9 @@ class WindowDatasetOp::Dataset : public DatasetBase {
return cardinality;
}
bool IsStateful() const override { return input_->IsStateful(); }
Status CheckExternalState() const override {
return input_->CheckExternalState();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,

View File

@ -574,49 +574,6 @@ TEST_P(ParameterizedWindowDatasetOpTest, Cardinality) {
EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
}
TEST_P(ParameterizedWindowDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
std::unique_ptr<OpKernel> window_dataset_kernel;
TF_ASSERT_OK(CreateWindowDatasetKernel(test_case.expected_output_dtypes,
test_case.expected_output_shapes,
&window_dataset_kernel));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateRangeDataset<int64>(
test_case.range_data_param.start, test_case.range_data_param.end,
test_case.range_data_param.step, "range", &range_dataset));
Tensor range_dataset_tensor(DT_VARIANT, TensorShape({}));
TF_ASSERT_OK(
StoreDatasetInVariantTensor(range_dataset, &range_dataset_tensor));
Tensor size = test_case.size;
Tensor shift = test_case.shift;
Tensor stride = test_case.stride;
Tensor drop_remainder = test_case.drop_remainder;
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&range_dataset_tensor), TensorValue(&size),
TensorValue(&shift), TensorValue(&stride),
TensorValue(&drop_remainder)});
std::unique_ptr<OpKernelContext> window_dataset_op_ctx;
TF_ASSERT_OK(CreateWindowDatasetContext(window_dataset_kernel.get(), &inputs,
&window_dataset_op_ctx));
DatasetBase* dataset;
TF_ASSERT_OK(CreateDataset(window_dataset_kernel.get(),
window_dataset_op_ctx.get(), &dataset));
core::ScopedUnref scoped_unref_dataset(dataset);
std::unique_ptr<SerializationContext> serialization_context;
TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(dataset->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedWindowDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TestCase test_case = GetParam();

View File

@ -87,6 +87,13 @@ class ZipDatasetOp::Dataset : public DatasetBase {
return result;
}
Status CheckExternalState() const override {
for (const auto& input : inputs_) {
TF_RETURN_IF_ERROR(input->CheckExternalState());
}
return Status::OK();
}
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
@ -103,15 +110,6 @@ class ZipDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
bool IsStateful() const override {
for (const auto& input : inputs_) {
if (input->IsStateful()) {
return true;
}
}
return false;
}
private:
class Iterator : public DatasetIterator<Dataset> {
public:

View File

@ -333,41 +333,6 @@ TEST_P(ParameterizedZipDatasetOpTest, Cardinality) {
test_case.expected_outputs.size() / num_tensors_per_slice);
}
TEST_F(ZipDatasetOpTest, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = TestCase1();
std::vector<Tensor> range_dataset_tensors;
range_dataset_tensors.reserve(test_case.input_range_dataset_params.size());
TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params,
&range_dataset_tensors));
gtl::InlinedVector<TensorValue, 4> inputs;
inputs.reserve(range_dataset_tensors.size());
for (auto &tensor : range_dataset_tensors) {
inputs.emplace_back(&tensor);
}
std::unique_ptr<OpKernel> dataset_kernel;
int num_tensors_per_slice = test_case.input_range_dataset_params.size();
TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}},
inputs.size(), &dataset_kernel));
std::unique_ptr<OpKernelContext> dataset_kernel_ctx;
TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx));
DatasetBase *zip_dataset;
TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(),
&zip_dataset));
core::ScopedUnref scoped_unref(zip_dataset);
std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
VariantTensorData data;
VariantTensorDataWriter writer(&data);
TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
}
TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));

View File

@ -33,6 +33,7 @@ from tensorflow.python.platform import test
class ReplicateClusterTest(test_base.DatasetTestBase):
def setUp(self):
super(ReplicateClusterTest, self).setUp()
# Start the local server.
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
@ -99,7 +100,7 @@ class ReplicateClusterTest(test_base.DatasetTestBase):
it1 = dataset_ops.make_initializable_iterator(dataset1)
# We don't support stateful ops in functions as of now.
with session.Session(self._target) as sess:
with self.assertRaises(errors.InvalidArgumentError):
with self.assertRaises(errors.FailedPreconditionError):
sess.run(it1.initializer)

Some files were not shown because too many files have changed in this diff Show More