Throw a warning instead of error when CheckOpKernelInput() fails because the number of kernel input tensors is allowed to mismatch with input names
This commit is contained in:
parent
a8f9799bda
commit
4aad621976
@ -25,7 +25,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
|
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
@ -83,6 +82,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/public/session_options.h"
|
#include "tensorflow/core/public/session_options.h"
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace data {
|
namespace data {
|
||||||
@ -321,7 +321,10 @@ Status DatasetOpsTestBase::CreateDatasetContext(
|
|||||||
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
||||||
std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
|
std::unique_ptr<OpKernelContext::Params>* dataset_context_params,
|
||||||
std::unique_ptr<OpKernelContext>* dataset_context) {
|
std::unique_ptr<OpKernelContext>* dataset_context) {
|
||||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*dateset_kernel, *inputs));
|
Status status = CheckOpKernelInput(*dateset_kernel, *inputs);
|
||||||
|
if (!status.ok()) {
|
||||||
|
VLOG(0) << "WARNING: " << status.ToString();
|
||||||
|
}
|
||||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(
|
TF_RETURN_IF_ERROR(CreateOpKernelContext(
|
||||||
dateset_kernel, inputs, dataset_context_params, dataset_context));
|
dateset_kernel, inputs, dataset_context_params, dataset_context));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -529,10 +532,10 @@ Status DatasetOpsTestBase::CreateSerializationContext(
|
|||||||
|
|
||||||
Status DatasetOpsTestBase::CheckOpKernelInput(
|
Status DatasetOpsTestBase::CheckOpKernelInput(
|
||||||
const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
|
const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
|
||||||
if (kernel.input_types().size() != inputs.size()) {
|
if (kernel.num_inputs() != inputs.size()) {
|
||||||
return errors::Internal("The number of input elements should be ",
|
return errors::InvalidArgument("The number of input elements should be ",
|
||||||
kernel.input_types().size(),
|
kernel.num_inputs(),
|
||||||
", but got: ", inputs.size());
|
", but got: ", inputs.size());
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -254,7 +254,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
const DatasetBase* const selector_input_;
|
const DatasetBase* const selector_input_;
|
||||||
const std::vector<DatasetBase*> data_inputs_;
|
const std::vector<DatasetBase*> data_inputs_;
|
||||||
std::vector<PartialTensorShape> output_shapes_;
|
std::vector<PartialTensorShape> output_shapes_;
|
||||||
}; // namespace experimental
|
};
|
||||||
|
|
||||||
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
|
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
|
||||||
OpKernelConstruction* ctx)
|
OpKernelConstruction* ctx)
|
||||||
|
Loading…
Reference in New Issue
Block a user