Merge pull request #40118 from zhuzilin:iterator-next-check
PiperOrigin-RevId: 314969403 Change-Id: I11620f4fad44c3a044e7b9643daebd53fb68dc04
This commit is contained in:
commit
e0b451c113
@ -455,6 +455,16 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status VerifyTypeMatch(const DataType& expected, const DataType& received,
|
||||||
|
int index) {
|
||||||
|
if (expected != received) {
|
||||||
|
return errors::InvalidArgument("Data type mismatch at component ", index,
|
||||||
|
": expected ", DataTypeString(expected),
|
||||||
|
" but got ", DataTypeString(received), ".");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status VerifyTypesMatch(const DataTypeVector& expected,
|
Status VerifyTypesMatch(const DataTypeVector& expected,
|
||||||
const DataTypeVector& received) {
|
const DataTypeVector& received) {
|
||||||
if (expected.size() != received.size()) {
|
if (expected.size() != received.size()) {
|
||||||
@ -463,12 +473,30 @@ Status VerifyTypesMatch(const DataTypeVector& expected,
|
|||||||
" types but got ", received.size(), ".");
|
" types but got ", received.size(), ".");
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < expected.size(); ++i) {
|
for (size_t i = 0; i < expected.size(); ++i) {
|
||||||
if (expected[i] != received[i]) {
|
TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
|
||||||
return errors::InvalidArgument("Data type mismatch at component ", i,
|
|
||||||
": expected ", DataTypeString(expected[i]),
|
|
||||||
" but got ", DataTypeString(received[i]),
|
|
||||||
".");
|
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status VerifyTypesMatch(const DataTypeVector& expected,
|
||||||
|
const std::vector<Tensor>& received) {
|
||||||
|
if (expected.size() != received.size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Number of components does not match: expected ", expected.size(),
|
||||||
|
" types but got ", received.size(), ".");
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < expected.size(); ++i) {
|
||||||
|
TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status VerifyShapeCompatible(const PartialTensorShape& expected,
|
||||||
|
const PartialTensorShape& received, int index) {
|
||||||
|
if (!expected.IsCompatibleWith(received)) {
|
||||||
|
return errors::InvalidArgument("Incompatible shapes at component ", index,
|
||||||
|
": expected ", expected.DebugString(),
|
||||||
|
" but got ", received.DebugString(), ".");
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -481,12 +509,22 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
|
|||||||
" shapes but got ", received.size(), ".");
|
" shapes but got ", received.size(), ".");
|
||||||
}
|
}
|
||||||
for (size_t i = 0; i < expected.size(); ++i) {
|
for (size_t i = 0; i < expected.size(); ++i) {
|
||||||
if (!expected[i].IsCompatibleWith(received[i])) {
|
TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
|
||||||
return errors::InvalidArgument("Incompatible shapes at component ", i,
|
|
||||||
": expected ", expected[i].DebugString(),
|
|
||||||
" but got ", received[i].DebugString(),
|
|
||||||
".");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
|
||||||
|
const std::vector<Tensor>& received) {
|
||||||
|
if (expected.size() != received.size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Number of components does not match: expected ", expected.size(),
|
||||||
|
" shapes but got ", received.size(), ".");
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < expected.size(); ++i) {
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
VerifyShapeCompatible(expected[i], received[i].shape(), i));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -94,11 +94,17 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
|
|||||||
Status VerifyTypesMatch(const DataTypeVector& expected,
|
Status VerifyTypesMatch(const DataTypeVector& expected,
|
||||||
const DataTypeVector& received);
|
const DataTypeVector& received);
|
||||||
|
|
||||||
|
Status VerifyTypesMatch(const DataTypeVector& expected,
|
||||||
|
const std::vector<Tensor>& received);
|
||||||
|
|
||||||
// Returns Status::OK() if `expected` and `received` shapes are compatible,
|
// Returns Status::OK() if `expected` and `received` shapes are compatible,
|
||||||
// errors::InvalidArgument otherwise.
|
// errors::InvalidArgument otherwise.
|
||||||
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
|
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
|
||||||
const std::vector<PartialTensorShape>& received);
|
const std::vector<PartialTensorShape>& received);
|
||||||
|
|
||||||
|
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
|
||||||
|
const std::vector<Tensor>& received);
|
||||||
|
|
||||||
// Returns a stable hash of the subgraph rooted at the given node.
|
// Returns a stable hash of the subgraph rooted at the given node.
|
||||||
//
|
//
|
||||||
// NOTE: There is currently no guarantee that the hash of a subgraph will stay
|
// NOTE: There is currently no guarantee that the hash of a subgraph will stay
|
||||||
|
@ -548,7 +548,10 @@ namespace {
|
|||||||
class ToSingleElementOp : public HybridAsyncOpKernel {
|
class ToSingleElementOp : public HybridAsyncOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit ToSingleElementOp(OpKernelConstruction* ctx)
|
explicit ToSingleElementOp(OpKernelConstruction* ctx)
|
||||||
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {}
|
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status DoCompute(OpKernelContext* ctx) override {
|
Status DoCompute(OpKernelContext* ctx) override {
|
||||||
@ -580,8 +583,9 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
|||||||
if (end_of_sequence) {
|
if (end_of_sequence) {
|
||||||
return errors::InvalidArgument("Dataset was empty.");
|
return errors::InvalidArgument("Dataset was empty.");
|
||||||
}
|
}
|
||||||
|
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
|
||||||
|
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
|
||||||
for (int i = 0; i < components.size(); ++i) {
|
for (int i = 0; i < components.size(); ++i) {
|
||||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
|
||||||
ctx->set_output(i, components[i]);
|
ctx->set_output(i, components[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -593,6 +597,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
|
|||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataTypeVector output_types_;
|
||||||
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ReduceDatasetOp : public HybridAsyncOpKernel {
|
class ReduceDatasetOp : public HybridAsyncOpKernel {
|
||||||
@ -674,33 +682,9 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
|
|||||||
std::swap(reduce_func_output, state);
|
std::swap(reduce_func_output, state);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state.size() != output_types_.size()) {
|
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
|
||||||
return errors::InvalidArgument(
|
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
|
||||||
"The number of result elements does not match "
|
|
||||||
"the size of output types: ",
|
|
||||||
state.size(), " vs. ", output_types_.size());
|
|
||||||
}
|
|
||||||
if (state.size() != output_shapes_.size()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The number of result elements does not match "
|
|
||||||
"the size of output shapes: ",
|
|
||||||
state.size(), " vs. ", output_shapes_.size());
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < state.size(); ++i) {
|
for (size_t i = 0; i < state.size(); ++i) {
|
||||||
if (state[i].dtype() != output_types_[i]) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The result does not match the expected type for "
|
|
||||||
"component ",
|
|
||||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
|
||||||
". Actual: ", DataTypeString(state[i].dtype()), ".");
|
|
||||||
}
|
|
||||||
if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"The result does not match the expected shape for "
|
|
||||||
"component ",
|
|
||||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
|
||||||
". Actual: ", state[i].shape().DebugString(), ".");
|
|
||||||
}
|
|
||||||
ctx->set_output(i, state[i]);
|
ctx->set_output(i, state[i]);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -917,8 +901,9 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
|
|||||||
if (end_of_sequence) {
|
if (end_of_sequence) {
|
||||||
return errors::OutOfRange("End of sequence");
|
return errors::OutOfRange("End of sequence");
|
||||||
}
|
}
|
||||||
|
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
|
||||||
|
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
|
||||||
for (int i = 0; i < components.size(); ++i) {
|
for (int i = 0; i < components.size(); ++i) {
|
||||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
|
||||||
ctx->set_output(i, components[i]);
|
ctx->set_output(i, components[i]);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel {
|
|||||||
class IteratorGetNextOp : public HybridAsyncOpKernel {
|
class IteratorGetNextOp : public HybridAsyncOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
|
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
|
||||||
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {}
|
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||||
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||||
|
}
|
||||||
|
|
||||||
AsyncOpKernel* AsAsync() override;
|
AsyncOpKernel* AsAsync() override;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Status DoCompute(OpKernelContext* ctx) override;
|
Status DoCompute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataTypeVector output_types_;
|
||||||
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class DeleteIteratorOp : public HybridAsyncOpKernel {
|
class DeleteIteratorOp : public HybridAsyncOpKernel {
|
||||||
|
Loading…
Reference in New Issue
Block a user