Pulling out kExternalStatePolicy as its own const char instead of re-using the DatasetToGraphOp constant

PiperOrigin-RevId: 286386422
Change-Id: I4432145c3bbe4903ff121b3e0fb9b488af30a610
This commit is contained in:
Rohan Jain 2019-12-19 07:08:28 -08:00 committed by TensorFlower Gardener
parent de5705ecbe
commit 69111e174c
3 changed files with 71 additions and 61 deletions

View File

@ -1115,7 +1115,6 @@ tf_kernel_library(
hdrs = ["iterator_ops.h"],
deps = [
":captured_function",
":dataset_ops",
":dataset_utils",
":optional_ops",
":unbounded_thread_pool",

View File

@ -33,7 +33,6 @@ limitations under the License.
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_ops.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
@ -63,6 +62,9 @@ const char kOutputTypes[] = "output_types";
} // namespace
/* static */ constexpr const char* const
SerializeIteratorOp::kExternalStatePolicy;
Status IteratorResource::GetNext(OpKernelContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@ -1067,21 +1069,18 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}
namespace {
class SerializeIteratorOp : public OpKernel {
public:
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
if (ctx->HasAttr(DatasetToGraphOp::kExternalStatePolicy)) {
SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
if (ctx->HasAttr(kExternalStatePolicy)) {
int64 state_change_option;
OP_REQUIRES_OK(ctx, ctx->GetAttr(DatasetToGraphOp::kExternalStatePolicy,
&state_change_option));
OP_REQUIRES_OK(ctx,
ctx->GetAttr(kExternalStatePolicy, &state_change_option));
external_state_policy_ =
SerializationContext::ExternalStatePolicy(state_change_option);
}
}
void Compute(OpKernelContext* ctx) override {
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
errors::InvalidArgument("resource_handle must be a scalar"));
@ -1098,22 +1097,13 @@ class SerializeIteratorOp : public OpKernel {
OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
iterator_resource));
Tensor* serialized_t;
OP_REQUIRES_OK(
ctx, ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
&serialized_t));
OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
}
private:
SerializationContext::ExternalStatePolicy external_state_policy_ =
SerializationContext::ExternalStatePolicy::kWarn;
};
class DeserializeIteratorOp : public OpKernel {
public:
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
// Validate that the handle corresponds to a real resource, and
// that it is an IteratorResource.
IteratorResource* iterator_resource;
@ -1124,10 +1114,10 @@ class DeserializeIteratorOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
IteratorVariantSerializer serializer;
OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
OP_REQUIRES_OK(ctx,
iterator_resource->Restore(ctx, serializer.GetReader()));
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, serializer.GetReader()));
}
};
namespace {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),

View File

@ -249,6 +249,27 @@ class IteratorFromStringHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
class SerializeIteratorOp : public OpKernel {
public:
static constexpr const char* const kExternalStatePolicy =
"external_state_policy";
explicit SerializeIteratorOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
private:
SerializationContext::ExternalStatePolicy external_state_policy_ =
SerializationContext::ExternalStatePolicy::kWarn;
};
class DeserializeIteratorOp : public OpKernel {
public:
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override;
};
} // namespace data
} // namespace tensorflow