Pulling out kExternalStatePolicy as its own const char instead of re-using the DatasetToGraphOp constant
PiperOrigin-RevId: 286386422 Change-Id: I4432145c3bbe4903ff121b3e0fb9b488af30a610
This commit is contained in:
parent
de5705ecbe
commit
69111e174c
@ -1115,7 +1115,6 @@ tf_kernel_library(
|
||||
hdrs = ["iterator_ops.h"],
|
||||
deps = [
|
||||
":captured_function",
|
||||
":dataset_ops",
|
||||
":dataset_utils",
|
||||
":optional_ops",
|
||||
":unbounded_thread_pool",
|
||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/kernels/data/captured_function.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_ops.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
#include "tensorflow/core/kernels/data/optional_ops.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
@ -63,6 +62,9 @@ const char kOutputTypes[] = "output_types";
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */ constexpr const char* const
|
||||
SerializeIteratorOp::kExternalStatePolicy;
|
||||
|
||||
Status IteratorResource::GetNext(OpKernelContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) {
|
||||
@ -1067,21 +1069,18 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
|
||||
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class SerializeIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
if (ctx->HasAttr(DatasetToGraphOp::kExternalStatePolicy)) {
|
||||
SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {
|
||||
if (ctx->HasAttr(kExternalStatePolicy)) {
|
||||
int64 state_change_option;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(DatasetToGraphOp::kExternalStatePolicy,
|
||||
&state_change_option));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetAttr(kExternalStatePolicy, &state_change_option));
|
||||
external_state_policy_ =
|
||||
SerializationContext::ExternalStatePolicy(state_change_option);
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
const Tensor& resource_handle_t = ctx->input(0);
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||
@ -1098,22 +1097,13 @@ class SerializeIteratorOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
|
||||
iterator_resource));
|
||||
Tensor* serialized_t;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
|
||||
&serialized_t));
|
||||
OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
|
||||
}
|
||||
|
||||
private:
|
||||
SerializationContext::ExternalStatePolicy external_state_policy_ =
|
||||
SerializationContext::ExternalStatePolicy::kWarn;
|
||||
};
|
||||
|
||||
class DeserializeIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
// Validate that the handle corresponds to a real resource, and
|
||||
// that it is an IteratorResource.
|
||||
IteratorResource* iterator_resource;
|
||||
@ -1124,10 +1114,10 @@ class DeserializeIteratorOp : public OpKernel {
|
||||
OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
|
||||
IteratorVariantSerializer serializer;
|
||||
OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
iterator_resource->Restore(ctx, serializer.GetReader()));
|
||||
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, serializer.GetReader()));
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
|
||||
|
@ -249,6 +249,27 @@ class IteratorFromStringHandleOp : public OpKernel {
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
};
|
||||
|
||||
class SerializeIteratorOp : public OpKernel {
|
||||
public:
|
||||
static constexpr const char* const kExternalStatePolicy =
|
||||
"external_state_policy";
|
||||
|
||||
explicit SerializeIteratorOp(OpKernelConstruction* ctx);
|
||||
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
SerializationContext::ExternalStatePolicy external_state_policy_ =
|
||||
SerializationContext::ExternalStatePolicy::kWarn;
|
||||
};
|
||||
|
||||
class DeserializeIteratorOp : public OpKernel {
|
||||
public:
|
||||
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user