Pulling out kExternalStatePolicy as its own const char instead of re-using the DatasetToGraphOp constant
PiperOrigin-RevId: 286386422 Change-Id: I4432145c3bbe4903ff121b3e0fb9b488af30a610
This commit is contained in:
parent
de5705ecbe
commit
69111e174c
@ -1115,7 +1115,6 @@ tf_kernel_library(
|
|||||||
hdrs = ["iterator_ops.h"],
|
hdrs = ["iterator_ops.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":captured_function",
|
":captured_function",
|
||||||
":dataset_ops",
|
|
||||||
":dataset_utils",
|
":dataset_utils",
|
||||||
":optional_ops",
|
":optional_ops",
|
||||||
":unbounded_thread_pool",
|
":unbounded_thread_pool",
|
||||||
|
@ -33,7 +33,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/variant_tensor_data.h"
|
#include "tensorflow/core/framework/variant_tensor_data.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/kernels/data/captured_function.h"
|
#include "tensorflow/core/kernels/data/captured_function.h"
|
||||||
#include "tensorflow/core/kernels/data/dataset_ops.h"
|
|
||||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||||
#include "tensorflow/core/kernels/data/optional_ops.h"
|
#include "tensorflow/core/kernels/data/optional_ops.h"
|
||||||
#include "tensorflow/core/kernels/ops_util.h"
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
@ -63,6 +62,9 @@ const char kOutputTypes[] = "output_types";
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
/* static */ constexpr const char* const
|
||||||
|
SerializeIteratorOp::kExternalStatePolicy;
|
||||||
|
|
||||||
Status IteratorResource::GetNext(OpKernelContext* ctx,
|
Status IteratorResource::GetNext(OpKernelContext* ctx,
|
||||||
std::vector<Tensor>* out_tensors,
|
std::vector<Tensor>* out_tensors,
|
||||||
bool* end_of_sequence) {
|
bool* end_of_sequence) {
|
||||||
@ -1067,67 +1069,55 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
|
|||||||
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
|
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
SerializeIteratorOp::SerializeIteratorOp(OpKernelConstruction* ctx)
|
||||||
|
: OpKernel(ctx) {
|
||||||
class SerializeIteratorOp : public OpKernel {
|
if (ctx->HasAttr(kExternalStatePolicy)) {
|
||||||
public:
|
int64 state_change_option;
|
||||||
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
|
||||||
if (ctx->HasAttr(DatasetToGraphOp::kExternalStatePolicy)) {
|
|
||||||
int64 state_change_option;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(DatasetToGraphOp::kExternalStatePolicy,
|
|
||||||
&state_change_option));
|
|
||||||
external_state_policy_ =
|
|
||||||
SerializationContext::ExternalStatePolicy(state_change_option);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
const Tensor& resource_handle_t = ctx->input(0);
|
|
||||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
|
||||||
errors::InvalidArgument("resource_handle must be a scalar"));
|
|
||||||
// Validate that the handle corresponds to a real resource, and
|
|
||||||
// that it is an IteratorResource.
|
|
||||||
IteratorResource* iterator_resource;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
|
||||||
core::ScopedUnref unref_iterator(iterator_resource);
|
|
||||||
IteratorVariantSerializer serializer;
|
|
||||||
SerializationContext::Params params;
|
|
||||||
params.external_state_policy = external_state_policy_;
|
|
||||||
SerializationContext serialization_ctx(params);
|
|
||||||
OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
|
|
||||||
iterator_resource));
|
|
||||||
Tensor* serialized_t;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
|
|
||||||
&serialized_t));
|
|
||||||
OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
SerializationContext::ExternalStatePolicy external_state_policy_ =
|
|
||||||
SerializationContext::ExternalStatePolicy::kWarn;
|
|
||||||
};
|
|
||||||
|
|
||||||
class DeserializeIteratorOp : public OpKernel {
|
|
||||||
public:
|
|
||||||
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
|
||||||
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
// Validate that the handle corresponds to a real resource, and
|
|
||||||
// that it is an IteratorResource.
|
|
||||||
IteratorResource* iterator_resource;
|
|
||||||
OP_REQUIRES_OK(
|
|
||||||
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
|
||||||
core::ScopedUnref unref_iterator(iterator_resource);
|
|
||||||
const Tensor* serialized_t;
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
|
|
||||||
IteratorVariantSerializer serializer;
|
|
||||||
OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
|
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
iterator_resource->Restore(ctx, serializer.GetReader()));
|
ctx->GetAttr(kExternalStatePolicy, &state_change_option));
|
||||||
|
external_state_policy_ =
|
||||||
|
SerializationContext::ExternalStatePolicy(state_change_option);
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
|
void SerializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||||
|
const Tensor& resource_handle_t = ctx->input(0);
|
||||||
|
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
|
||||||
|
errors::InvalidArgument("resource_handle must be a scalar"));
|
||||||
|
// Validate that the handle corresponds to a real resource, and
|
||||||
|
// that it is an IteratorResource.
|
||||||
|
IteratorResource* iterator_resource;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||||
|
core::ScopedUnref unref_iterator(iterator_resource);
|
||||||
|
IteratorVariantSerializer serializer;
|
||||||
|
SerializationContext::Params params;
|
||||||
|
params.external_state_policy = external_state_policy_;
|
||||||
|
SerializationContext serialization_ctx(params);
|
||||||
|
OP_REQUIRES_OK(ctx, serializer.InitializeFromIterator(&serialization_ctx,
|
||||||
|
iterator_resource));
|
||||||
|
Tensor* serialized_t;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->allocate_output(0, TensorShape({serializer.NumTensors()}),
|
||||||
|
&serialized_t));
|
||||||
|
OP_REQUIRES_OK(ctx, serializer.Serialize(serialized_t));
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeserializeIteratorOp::Compute(OpKernelContext* ctx) {
|
||||||
|
// Validate that the handle corresponds to a real resource, and
|
||||||
|
// that it is an IteratorResource.
|
||||||
|
IteratorResource* iterator_resource;
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
|
||||||
|
core::ScopedUnref unref_iterator(iterator_resource);
|
||||||
|
const Tensor* serialized_t;
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t));
|
||||||
|
IteratorVariantSerializer serializer;
|
||||||
|
OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t));
|
||||||
|
OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, serializer.GetReader()));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
|
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
|
||||||
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
|
REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU).Priority(2),
|
||||||
|
@ -249,6 +249,27 @@ class IteratorFromStringHandleOp : public OpKernel {
|
|||||||
std::vector<PartialTensorShape> output_shapes_;
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SerializeIteratorOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
static constexpr const char* const kExternalStatePolicy =
|
||||||
|
"external_state_policy";
|
||||||
|
|
||||||
|
explicit SerializeIteratorOp(OpKernelConstruction* ctx);
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
SerializationContext::ExternalStatePolicy external_state_policy_ =
|
||||||
|
SerializationContext::ExternalStatePolicy::kWarn;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DeserializeIteratorOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user