Merge pull request #40118 from zhuzilin:iterator-next-check

PiperOrigin-RevId: 314969403
Change-Id: I11620f4fad44c3a044e7b9643daebd53fb68dc04
This commit is contained in:
TensorFlower Gardener 2020-06-05 12:08:01 -07:00
commit e0b451c113
4 changed files with 78 additions and 42 deletions

View File

@ -455,6 +455,16 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
return Status::OK();
}
Status VerifyTypeMatch(const DataType& expected, const DataType& received,
int index) {
if (expected != received) {
return errors::InvalidArgument("Data type mismatch at component ", index,
": expected ", DataTypeString(expected),
" but got ", DataTypeString(received), ".");
}
return Status::OK();
}
Status VerifyTypesMatch(const DataTypeVector& expected,
const DataTypeVector& received) {
if (expected.size() != received.size()) {
@ -463,12 +473,30 @@ Status VerifyTypesMatch(const DataTypeVector& expected,
" types but got ", received.size(), ".");
}
for (size_t i = 0; i < expected.size(); ++i) {
if (expected[i] != received[i]) {
return errors::InvalidArgument("Data type mismatch at component ", i,
": expected ", DataTypeString(expected[i]),
" but got ", DataTypeString(received[i]),
".");
}
TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i], i));
}
return Status::OK();
}
Status VerifyTypesMatch(const DataTypeVector& expected,
const std::vector<Tensor>& received) {
if (expected.size() != received.size()) {
return errors::InvalidArgument(
"Number of components does not match: expected ", expected.size(),
" types but got ", received.size(), ".");
}
for (size_t i = 0; i < expected.size(); ++i) {
TF_RETURN_IF_ERROR(VerifyTypeMatch(expected[i], received[i].dtype(), i));
}
return Status::OK();
}
Status VerifyShapeCompatible(const PartialTensorShape& expected,
const PartialTensorShape& received, int index) {
if (!expected.IsCompatibleWith(received)) {
return errors::InvalidArgument("Incompatible shapes at component ", index,
": expected ", expected.DebugString(),
" but got ", received.DebugString(), ".");
}
return Status::OK();
}
@ -481,12 +509,22 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
" shapes but got ", received.size(), ".");
}
for (size_t i = 0; i < expected.size(); ++i) {
if (!expected[i].IsCompatibleWith(received[i])) {
return errors::InvalidArgument("Incompatible shapes at component ", i,
": expected ", expected[i].DebugString(),
" but got ", received[i].DebugString(),
".");
}
TF_RETURN_IF_ERROR(VerifyShapeCompatible(expected[i], received[i], i));
}
return Status::OK();
}
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
const std::vector<Tensor>& received) {
if (expected.size() != received.size()) {
return errors::InvalidArgument(
"Number of components does not match: expected ", expected.size(),
" shapes but got ", received.size(), ".");
}
for (size_t i = 0; i < expected.size(); ++i) {
TF_RETURN_IF_ERROR(
VerifyShapeCompatible(expected[i], received[i].shape(), i));
}
return Status::OK();

View File

@ -94,11 +94,17 @@ Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
Status VerifyTypesMatch(const DataTypeVector& expected,
const DataTypeVector& received);
Status VerifyTypesMatch(const DataTypeVector& expected,
const std::vector<Tensor>& received);
// Returns Status::OK() if `expected` and `received` shapes are compatible,
// errors::InvalidArgument otherwise.
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
const std::vector<PartialTensorShape>& received);
Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
const std::vector<Tensor>& received);
// Returns a stable hash of the subgraph rooted at the given node.
//
// NOTE: There is currently no guarantee that the hash of a subgraph will stay

View File

@ -548,7 +548,10 @@ namespace {
class ToSingleElementOp : public HybridAsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx)
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {}
: HybridAsyncOpKernel(ctx, "tf_data_to_single_element") {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
protected:
Status DoCompute(OpKernelContext* ctx) override {
@ -580,8 +583,9 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
if (end_of_sequence) {
return errors::InvalidArgument("Dataset was empty.");
}
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
@ -593,6 +597,10 @@ class ToSingleElementOp : public HybridAsyncOpKernel {
}
return Status::OK();
}
private:
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
class ReduceDatasetOp : public HybridAsyncOpKernel {
@ -674,33 +682,9 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
std::swap(reduce_func_output, state);
}
if (state.size() != output_types_.size()) {
return errors::InvalidArgument(
"The number of result elements does not match "
"the size of output types: ",
state.size(), " vs. ", output_types_.size());
}
if (state.size() != output_shapes_.size()) {
return errors::InvalidArgument(
"The number of result elements does not match "
"the size of output shapes: ",
state.size(), " vs. ", output_shapes_.size());
}
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, state));
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, state));
for (size_t i = 0; i < state.size(); ++i) {
if (state[i].dtype() != output_types_[i]) {
return errors::InvalidArgument(
"The result does not match the expected type for "
"component ",
i, ". Expected: ", DataTypeString(output_types_[i]),
". Actual: ", DataTypeString(state[i].dtype()), ".");
}
if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) {
return errors::InvalidArgument(
"The result does not match the expected shape for "
"component ",
i, ". Expected: ", output_shapes_[i].DebugString(),
". Actual: ", state[i].shape().DebugString(), ".");
}
ctx->set_output(i, state[i]);
}
return Status::OK();
@ -917,8 +901,9 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) {
if (end_of_sequence) {
return errors::OutOfRange("End of sequence");
}
TF_RETURN_IF_ERROR(VerifyTypesMatch(output_types_, components));
TF_RETURN_IF_ERROR(VerifyShapesCompatible(output_shapes_, components));
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
return Status::OK();

View File

@ -216,12 +216,19 @@ class MakeIteratorOp : public HybridAsyncOpKernel {
class IteratorGetNextOp : public HybridAsyncOpKernel {
public:
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {}
: HybridAsyncOpKernel(ctx, "tf_data_iterator_get_next") {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
AsyncOpKernel* AsAsync() override;
protected:
Status DoCompute(OpKernelContext* ctx) override;
private:
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
};
class DeleteIteratorOp : public HybridAsyncOpKernel {