Merge pull request #40118 from zhuzilin:iterator-next-check
PiperOrigin-RevId: 314969403
Change-Id: I11620f4fad44c3a044e7b9643daebd53fb68dc04
commit e0b451c113
@@ -455,6 +455,16 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
   return Status::OK();
 }
 
+Status VerifyTypeMatch(const DataType& expected, const DataType& received,
+                       int index) {
+  if (expected != received) {
+    return errors::InvalidArgument("Data type mismatch at component ", index,
+                                   ": expected ", DataTypeString(expected),
+                                   " but got ", DataTypeString(received), ".");
+  }
+  return Status::OK();
+}
+
 Status VerifyTypesMatch(const DataTypeVector& expected,
                         const DataTypeVector& received) {
   if (expected.size() != received.size()) {
@@ -463,12 +473,30 @@ Status VerifyTypesMatch(const DataTypeVector& expected,
         " types but got ", received.size(), ".");
   }
   for (size_t i = 0; i < expected.size(); ++i) {
-    if (expected[i] != received[i]) {
-      return errors::InvalidArgument("Data type mismatch at component ", i,
-                                     ": expected ", DataTypeString(expected[i]),
-                                     " but got ", DataTypeString(received[i]),
-                                     ".");
-    }
+    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
   }
   return Status::OK();
 }
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+                        const std::vector<Tensor>& received) {
+  if (expected.size() != received.size()) {
+    return errors::InvalidArgument(
+        "Number of components does not match: expected ", expected.size(),
+        " types but got ", received.size(), ".");
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
+  }
+  return Status::OK();
+}
+
+Status VerifyShapeCompatible(const PartialTensorShape& expected,
+                             const PartialTensorShape& received, int index) {
+  if (!expected.IsCompatibleWith(received)) {
+    return errors::InvalidArgument("Incompatible shapes at component ", index,
+                                   ": expected ", expected.DebugString(),
+                                   " but got ", received.DebugString(), ".");
+  }
+  return Status::OK();
+}
@@ -481,12 +509,22 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
         " shapes but got ", received.size(), ".");
   }
   for (size_t i = 0; i < expected.size(); ++i) {
-    if (!expected[i].IsCompatibleWith(received[i])) {
-      return errors::InvalidArgument("Incompatible shapes at component ", i,
-                                     ": expected ", expected[i].DebugString(),
-                                     " but got ", received[i].DebugString(),
-                                     ".");
-    }
+    TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
   }
 
   return Status::OK();
 }
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+                              const std::vector<Tensor>& received) {
+  if (expected.size() != received.size()) {
+    return errors::InvalidArgument(
+        "Number of components does not match: expected ", expected.size(),
+        " shapes but got ", received.size(), ".");
+  }
+  for (size_t i = 0; i < expected.size(); ++i) {
+    TF_RETURN_IF_ERROR(
+        VerifyShapeCompatible(expected[i], received[i].shape(), i));
+  }
+
+  return Status::OK();
+}
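The hunks above factor the per-component checks into VerifyTypeMatch and VerifyShapeCompatible and add std::vector<Tensor> overloads of VerifyTypesMatch and VerifyShapesCompatible. Below is a minimal standalone sketch of the same pattern using simplified stand-ins (plain strings for dtypes, a trivial Status struct) rather than TensorFlow's types; it only illustrates how routing both overloads through one shared per-component helper keeps the error text identical regardless of which overload the caller picks.

#include <cstdio>
#include <string>
#include <vector>

// Simplified stand-ins for tensorflow::Status and Tensor (not TF code).
struct Status {
  bool ok = true;
  std::string message;
};
struct FakeTensor {
  std::string dtype;  // stand-in for Tensor::dtype()
};

// Shared per-component check, analogous to VerifyTypeMatch above.
Status VerifyTypeMatch(const std::string& expected,
                       const std::string& received, int index) {
  if (expected != received) {
    return {false, "Data type mismatch at component " + std::to_string(index) +
                       ": expected " + expected + " but got " + received + "."};
  }
  return {};
}

// Overload taking a list of dtypes, analogous to the DataTypeVector version.
Status VerifyTypesMatch(const std::vector<std::string>& expected,
                        const std::vector<std::string>& received) {
  if (expected.size() != received.size()) {
    return {false, "Number of components does not match."};
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    Status s = VerifyTypeMatch(expected[i], received[i], static_cast<int>(i));
    if (!s.ok) return s;  // plays the role of TF_RETURN_IF_ERROR
  }
  return {};
}

// Overload taking produced values, analogous to the std::vector<Tensor> version.
Status VerifyTypesMatch(const std::vector<std::string>& expected,
                        const std::vector<FakeTensor>& received) {
  if (expected.size() != received.size()) {
    return {false, "Number of components does not match."};
  }
  for (size_t i = 0; i < expected.size(); ++i) {
    Status s =
        VerifyTypeMatch(expected[i], received[i].dtype, static_cast<int>(i));
    if (!s.ok) return s;
  }
  return {};
}

int main() {
  Status s = VerifyTypesMatch({"float", "int32"},
                              std::vector<FakeTensor>{{"float"}, {"string"}});
  std::printf("%s\n", s.ok ? "OK" : s.message.c_str());
  // Prints: Data type mismatch at component 1: expected int32 but got string.
}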
@@ -94,11 +94,17 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
 Status VerifyTypesMatch(const DataTypeVector& expected,
                         const DataTypeVector& received);
 
+Status VerifyTypesMatch(const DataTypeVector& expected,
+                        const std::vector<Tensor>& received);
+
 // Returns Status::OK() if `expected` and `received` shapes are compatible,
 // errors::InvalidArgument otherwise.
 Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
                               const std::vector<PartialTensorShape>& received);
 
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+                              const std::vector<Tensor>& received);
+
 // Returns a stable hash of the subgraph rooted at the given node.
 //
 // NOTE: There is currently no guarantee that the hash of a subgraph will stay
@@ -548,7 +548,10 @@ namespace {
 class ToSingleElementOp : public HybridAsyncOpKernel {
  public:
   explicit ToSingleElementOp(OpKernelConstruction* ctx)
-      : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {}
+      : HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
 
  protected:
   Status DoCompute(OpKernelContext* ctx) override {
@@ -580,8 +583,9 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
     if (end_of_sequence) {
       return errors::InvalidArgument("Dataset was empty.");
     }
+    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
+    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
     for (int i = 0; i < components.size(); ++i) {
-      // TODO(mrry): Check that the shapes match the shape attrs.
       ctx->set_output(i, components[i]);
     }
@@ -593,6 +597,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
     }
     return Status::OK();
   }
+
+ private:
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
 };
 
 class ReduceDatasetOp : public HybridAsyncOpKernel {
@@ -674,33 +682,9 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
       std::swap(reduce_func_output, state);
     }
 
-    if (state.size() != output_types_.size()) {
-      return errors::InvalidArgument(
-          "The number of result elements does not match "
-          "the size of output types: ",
-          state.size(), " vs. ", output_types_.size());
-    }
-    if (state.size() != output_shapes_.size()) {
-      return errors::InvalidArgument(
-          "The number of result elements does not match "
-          "the size of output shapes: ",
-          state.size(), " vs. ", output_shapes_.size());
-    }
+    TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
+    TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
     for (size_t i = 0; i < state.size(); ++i) {
-      if (state[i].dtype() != output_types_[i]) {
-        return errors::InvalidArgument(
-            "The result does not match the expected type for "
-            "component ",
-            i, ". Expected: ", DataTypeString(output_types_[i]),
-            ". Actual: ", DataTypeString(state[i].dtype()), ".");
-      }
-      if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) {
-        return errors::InvalidArgument(
-            "The result does not match the expected shape for "
-            "component ",
-            i, ". Expected: ", output_shapes_[i].DebugString(),
-            ". Actual: ", state[i].shape().DebugString(), ".");
-      }
       ctx->set_output(i, state[i]);
     }
     return Status::OK();
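The hunk above replaces hand-written size and per-component checks with TF_RETURN_IF_ERROR calls into the shared helpers. As a rough illustration of that early-return style, here is a simplified stand-in macro and Status (not TensorFlow's actual definitions) showing how the first failing check short-circuits the rest.

#include <cstddef>
#include <cstdio>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
};

// Evaluates an expression returning Status and bails out on the first error,
// mimicking how TF_RETURN_IF_ERROR is used in the hunk above.
#define RETURN_IF_ERROR(expr)   \
  do {                          \
    Status _s = (expr);         \
    if (!_s.ok) return _s;      \
  } while (0)

Status CheckSizesMatch(size_t expected, size_t received) {
  if (expected != received) {
    return {false, "Number of components does not match: expected " +
                       std::to_string(expected) + " but got " +
                       std::to_string(received) + "."};
  }
  return {};
}

Status Validate() {
  RETURN_IF_ERROR(CheckSizesMatch(2, 2));  // passes, execution continues
  RETURN_IF_ERROR(CheckSizesMatch(2, 3));  // fails, Validate() returns here
  RETURN_IF_ERROR(CheckSizesMatch(4, 4));  // never evaluated
  return {};
}

int main() {
  Status s = Validate();
  std::printf("%s\n", s.ok ? "OK" : s.message.c_str());
}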
@@ -917,8 +901,9 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
   if (end_of_sequence) {
     return errors::OutOfRange("End of sequence");
   }
+  TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
+  TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
   for (int i = 0; i < components.size(); ++i) {
-    // TODO(mrry): Check that the shapes match the shape attrs.
     ctx->set_output(i, components[i]);
   }
   return Status::OK();
@@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel {
 class IteratorGetNextOp : public HybridAsyncOpKernel {
  public:
   explicit IteratorGetNextOp(OpKernelConstruction* ctx)
-      : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {}
+      : HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+  }
 
   AsyncOpKernel* AsAsync() override;
 
  protected:
   Status DoCompute(OpKernelContext* ctx) override;
 
+ private:
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
 };
 
 class DeleteIteratorOp : public HybridAsyncOpKernel {
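Taken together, the kernel-side hunks follow one pattern: the kernel reads its expected output_types and output_shapes attrs once in the constructor, and DoCompute verifies every element it is about to emit against that signature. Below is a standalone sketch of that flow with simplified stand-ins (no TensorFlow types; the class and names are hypothetical and chosen only for illustration).

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Stand-in for a produced element: just a dtype tag.
struct Component {
  std::string dtype;
};

// Hypothetical sketch of the kernel pattern: capture the expected output
// signature at construction time, verify each produced element before
// emitting it. Not TensorFlow code.
class GetNextLikeOp {
 public:
  explicit GetNextLikeOp(std::vector<std::string> expected_types)
      : expected_types_(std::move(expected_types)) {}  // plays the GetAttr role

  // Returns true and "emits" the components only if the signature matches.
  bool Compute(const std::vector<Component>& components) const {
    if (components.size() != expected_types_.size()) {
      std::printf("Number of components does not match: expected %zu but got %zu.\n",
                  expected_types_.size(), components.size());
      return false;
    }
    for (size_t i = 0; i < components.size(); ++i) {
      if (components[i].dtype != expected_types_[i]) {
        std::printf("Data type mismatch at component %zu: expected %s but got %s.\n",
                    i, expected_types_[i].c_str(), components[i].dtype.c_str());
        return false;  // in the real op this becomes an InvalidArgument status
      }
      // The real kernel would now call ctx->set_output(i, components[i]).
    }
    return true;
  }

 private:
  std::vector<std::string> expected_types_;  // analogue of output_types_
};

int main() {
  GetNextLikeOp op({"float", "int32"});
  op.Compute({{"float"}, {"string"}});  // reports a mismatch at component 1
}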