From 69111e174ca40b6613276ffc0ce8bc7413f0401b Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Thu, 19 Dec 2019 07:08:28 -0800 Subject: [PATCH] Pulling out kExternalStatePolicy as its own const char instead of re-using the DatasetToGraphOp constant PiperOrigin-RevId: 286386422 Change-Id: I4432145c3bbe4903ff121b3e0fb9b488af30a610 --- tensorflow/core/kernels/data/BUILD | 1 - tensorflow/core/kernels/data/iterator_ops.cc | 110 +++++++++---------- tensorflow/core/kernels/data/iterator_ops.h | 21 ++++ 3 files changed, 71 insertions(+), 61 deletions(-) diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index bc1e6fc996c..81d4002c3fd 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1115,7 +1115,6 @@ tf_kernel_library( hdrs = ["iterator_ops.h"], deps = [ ":captured_function", - ":dataset_ops", ":dataset_utils", ":optional_ops", ":unbounded_thread_pool", diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 7ad08c8174c..b803f779145 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/kernels/data/captured_function.h" -#include "tensorflow/core/kernels/data/dataset_ops.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/optional_ops.h" #include "tensorflow/core/kernels/ops_util.h" @@ -63,6 +62,9 @@ const char kOutputTypes[] = "output_types"; } // namespace +/* static */ constexpr const char* const + SerializeIteratorOp::kExternalStatePolicy; + Status IteratorResource::GetNext(OpKernelContext* ctx, std::vector* out_tensors, bool* end_of_sequence) { @@ -1067,67 +1069,55 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) { resource_handle_t->scalar()() = resource_handle; } -namespace { - -class SerializeIteratorOp : public OpKernel { - public: - explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - if (ctx->HasAttr(DatasetToGraphOp::kExternalStatePolicy)) { - int64 state_change_option; - OP_REQUIRES_OK(ctx, ctx->GetAttr(DatasetToGraphOp::kExternalStatePolicy, - &state_change_option)); - external_state_policy_ = - SerializationContext::ExternalStatePolicy(state_change_option); - } - } - - void Compute(OpKernelContext* ctx) override { - const Tensor& resource_handle_t = ctx->input(0); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), - errors::InvalidArgument("resource_handle must be a scalar")); - // Validate that the handle corresponds to a real resource, and - // that it is an IteratorResource. - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - core::ScopedUnref unref_iterator(iterator_resource); - IteratorVariantSerializer serializer; - SerializationContext::Params params; - params.external_state_policy = external_state_policy_; - SerializationContext serialization_ctx(params); - OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx, - iterator_resource)); - Tensor* serialized_t; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(0, TensorShape({serializer.NumTensors()}), - &serialized_t)); - OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t)); - } - - private: - SerializationContext::ExternalStatePolicy external_state_policy_ = - SerializationContext::ExternalStatePolicy::kWarn; -}; - -class DeserializeIteratorOp : public OpKernel { - public: - explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - // Validate that the handle corresponds to a real resource, and - // that it is an IteratorResource. - IteratorResource* iterator_resource; - OP_REQUIRES_OK( - ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); - core::ScopedUnref unref_iterator(iterator_resource); - const Tensor* serialized_t; - OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t)); - IteratorVariantSerializer serializer; - OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t)); +SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + if (ctx->HasAttr(kExternalStatePolicy)) { + int64 state_change_option; OP_REQUIRES_OK(ctx, - iterator_resource->Restore(ctx, serializer.GetReader())); + ctx->GetAttr(kExternalStatePolicy, &state_change_option)); + external_state_policy_ = + SerializationContext::ExternalStatePolicy(state_change_option); } -}; +} + +void SerializeIteratorOp::Compute(OpKernelContext* ctx) { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + // Validate that the handle corresponds to a real resource, and + // that it is an IteratorResource. + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); + core::ScopedUnref unref_iterator(iterator_resource); + IteratorVariantSerializer serializer; + SerializationContext::Params params; + params.external_state_policy = external_state_policy_; + SerializationContext serialization_ctx(params); + OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx, + iterator_resource)); + Tensor* serialized_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({serializer.NumTensors()}), + &serialized_t)); + OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t)); +} + +void DeserializeIteratorOp::Compute(OpKernelContext* ctx) { + // Validate that the handle corresponds to a real resource, and + // that it is an IteratorResource. + IteratorResource* iterator_resource; + OP_REQUIRES_OK( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource)); + core::ScopedUnref unref_iterator(iterator_resource); + const Tensor* serialized_t; + OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t)); + IteratorVariantSerializer serializer; + OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t)); + OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, serializer.GetReader())); +} + +namespace { REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2), diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index f45fdaf0f19..4f43d787c44 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -249,6 +249,27 @@ class IteratorFromStringHandleOp : public OpKernel { std::vector output_shapes_; }; +class SerializeIteratorOp : public OpKernel { + public: + static constexpr const char* const kExternalStatePolicy = + "external_state_policy"; + + explicit SerializeIteratorOp(OpKernelConstruction* ctx); + + void Compute(OpKernelContext* ctx) override; + + private: + SerializationContext::ExternalStatePolicy external_state_policy_ = + SerializationContext::ExternalStatePolicy::kWarn; +}; + +class DeserializeIteratorOp : public OpKernel { + public: + explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override; +}; + } // namespace data } // namespace tensorflow