Throw a warning instead of error when CheckOpKernelInput() fails because the number of kernel input tensors is allowed to mismatch with input names

This commit is contained in:
feihugis 2020-03-18 00:00:21 -05:00
parent a8f9799bda
commit 4aad621976
2 changed files with 10 additions and 7 deletions

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
@ -83,6 +82,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h" #include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
@ -321,7 +321,10 @@ Status DatasetOpsTestBase::CreateDatasetContext(
gtl::InlinedVector<TensorValue, 4>* const inputs, gtl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext::Params>* dataset_context_params, std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
std::unique_ptr<OpKernelContext>* dataset_context) { std::unique_ptr<OpKernelContext>* dataset_context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*dateset_kernel, *inputs)); Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
if (!status.ok()) {
VLOG(0) << "WARNING: " << status.ToString();
}
TF_RETURN_IF_ERROR(CreateOpKernelContext( TF_RETURN_IF_ERROR(CreateOpKernelContext(
dateset_kernel, inputs, dataset_context_params, dataset_context)); dateset_kernel, inputs, dataset_context_params, dataset_context));
return Status::OK(); return Status::OK();
@ -529,10 +532,10 @@ Status DatasetOpsTestBase::CreateSerializationContext(
Status DatasetOpsTestBase::CheckOpKernelInput( Status DatasetOpsTestBase::CheckOpKernelInput(
const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) { const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
if (kernel.input_types().size() != inputs.size()) { if (kernel.num_inputs() != inputs.size()) {
return errors::Internal("The number of input elements should be ", return errors::InvalidArgument("The number of input elements should be ",
kernel.input_types().size(), kernel.num_inputs(),
", but got: ", inputs.size()); ", but got: ", inputs.size());
} }
return Status::OK(); return Status::OK();
} }

View File

@ -254,7 +254,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
const DatasetBase* const selector_input_; const DatasetBase* const selector_input_;
const std::vector<DatasetBase*> data_inputs_; const std::vector<DatasetBase*> data_inputs_;
std::vector<PartialTensorShape> output_shapes_; std::vector<PartialTensorShape> output_shapes_;
}; // namespace experimental };
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp( DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
OpKernelConstruction* ctx) OpKernelConstruction* ctx)