[tf.data] Simplify error handling in async kernels that use a BackgroundWorker.
This CL moves the main logic of three async tf.data kernels to a separate DoCompute() method that returns a Status. This enables us to use RAII for cleaning up objects in error cases, without having to manipulate the `done` callback. PiperOrigin-RevId: 282584164 Change-Id: I7c7a27c8f404826146adcf29c8b4a62912155be5
This commit is contained in:
parent
1b5e0b75ce
commit
4855dd694f
@ -48,93 +48,64 @@ class ToTFRecordOp : public AsyncOpKernel {
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
// The call to `iterator->GetNext()` may block and depend on an inter-op
|
||||
// thread pool thread, so we issue the call using a background thread.
|
||||
background_worker_.Schedule(std::bind(
|
||||
[this, ctx](std::function<void()>& done) {
|
||||
tstring filename;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ParseScalarArgument<tstring>(ctx, "filename", &filename),
|
||||
done);
|
||||
tstring compression_type;
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
ParseScalarArgument<tstring>(
|
||||
ctx, "compression_type", &compression_type),
|
||||
done);
|
||||
std::unique_ptr<WritableFile> file;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, ctx->env()->NewWritableFile(filename, &file), done);
|
||||
auto writer = absl::make_unique<io::RecordWriter>(
|
||||
file.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
|
||||
compression_type));
|
||||
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
FunctionHandleCache function_handle_cache(params.flr);
|
||||
params.function_handle_cache = &function_handle_cache;
|
||||
ResourceMgr resource_mgr;
|
||||
params.resource_mgr = &resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to deregister the cancellation callback.
|
||||
done = std::bind(
|
||||
[](const std::function<void()>& done,
|
||||
const std::function<void()>& deregister_fn) {
|
||||
deregister_fn();
|
||||
done();
|
||||
},
|
||||
std::move(done), std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator",
|
||||
&iterator),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to destroy the iterator before calling
|
||||
// the actual callback to avoid destruction races.
|
||||
IteratorBase* raw_iterator = iterator.release();
|
||||
done = std::bind(
|
||||
[raw_iterator](const std::function<void()>& done) {
|
||||
delete raw_iterator;
|
||||
done();
|
||||
},
|
||||
std::move(done));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence;
|
||||
do {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
|
||||
done);
|
||||
|
||||
if (!end_of_sequence) {
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, writer->WriteRecord(components[0].scalar<tstring>()()),
|
||||
done);
|
||||
}
|
||||
components.clear();
|
||||
} while (!end_of_sequence);
|
||||
done();
|
||||
},
|
||||
std::move(done)));
|
||||
background_worker_.Schedule([this, ctx, done = std::move(done)]() {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
Status DoCompute(OpKernelContext* ctx) {
|
||||
tstring filename;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseScalarArgument<tstring>(ctx, "filename", &filename));
|
||||
tstring compression_type;
|
||||
TF_RETURN_IF_ERROR(ParseScalarArgument<tstring>(ctx, "compression_type",
|
||||
&compression_type));
|
||||
std::unique_ptr<WritableFile> file;
|
||||
TF_RETURN_IF_ERROR(ctx->env()->NewWritableFile(filename, &file));
|
||||
auto writer = absl::make_unique<io::RecordWriter>(
|
||||
file.get(),
|
||||
io::RecordWriterOptions::CreateRecordWriterOptions(compression_type));
|
||||
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
FunctionHandleCache function_handle_cache(params.flr);
|
||||
params.function_handle_cache = &function_handle_cache;
|
||||
ResourceMgr resource_mgr;
|
||||
params.resource_mgr = &resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence;
|
||||
do {
|
||||
TF_RETURN_IF_ERROR(
|
||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
|
||||
if (!end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteRecord(components[0].scalar<tstring>()()));
|
||||
}
|
||||
components.clear();
|
||||
} while (!end_of_sequence);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BackgroundWorker background_worker_;
|
||||
};
|
||||
|
||||
|
@ -453,96 +453,60 @@ class ToSingleElementOp : public AsyncOpKernel {
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
background_worker_.Schedule(std::bind(
|
||||
[ctx](std::function<void()>& done) {
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
FunctionHandleCache function_handle_cache(params.flr);
|
||||
params.function_handle_cache = &function_handle_cache;
|
||||
ResourceMgr resource_mgr;
|
||||
params.resource_mgr = &resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to deregister the cancellation callback.
|
||||
done = std::bind(
|
||||
[](const std::function<void()>& done,
|
||||
const std::function<void()>& deregister_fn) {
|
||||
deregister_fn();
|
||||
done();
|
||||
},
|
||||
std::move(done), std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
dataset->MakeIterator(&iter_ctx, "SingleElementIterator",
|
||||
&iterator),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to destroy the iterator before calling
|
||||
// the actual callback to avoid destruction races.
|
||||
IteratorBase* raw_iterator = iterator.release();
|
||||
done = std::bind(
|
||||
[raw_iterator](const std::function<void()>& done) {
|
||||
delete raw_iterator;
|
||||
done();
|
||||
},
|
||||
std::move(done));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence = false;
|
||||
|
||||
Status s =
|
||||
raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
|
||||
if (!s.ok()) {
|
||||
ctx->SetStatus(s);
|
||||
done();
|
||||
return;
|
||||
}
|
||||
if (end_of_sequence) {
|
||||
ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
|
||||
done();
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
|
||||
components.clear();
|
||||
s.Update(
|
||||
raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
if (!s.ok()) {
|
||||
ctx->SetStatus(s);
|
||||
done();
|
||||
return;
|
||||
}
|
||||
if (!end_of_sequence) {
|
||||
ctx->SetStatus(
|
||||
errors::InvalidArgument("Dataset had more than one element."));
|
||||
done();
|
||||
return;
|
||||
}
|
||||
done();
|
||||
},
|
||||
std::move(done)));
|
||||
background_worker_.Schedule([this, ctx, done = std::move(done)]() {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
Status DoCompute(OpKernelContext* ctx) {
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
FunctionHandleCache function_handle_cache(params.flr);
|
||||
params.function_handle_cache = &function_handle_cache;
|
||||
ResourceMgr resource_mgr;
|
||||
params.resource_mgr = &resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator));
|
||||
|
||||
std::vector<Tensor> components;
|
||||
components.reserve(dataset->output_dtypes().size());
|
||||
bool end_of_sequence = false;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
|
||||
if (end_of_sequence) {
|
||||
return errors::InvalidArgument("Dataset was empty.");
|
||||
}
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
|
||||
components.clear();
|
||||
TF_RETURN_IF_ERROR(
|
||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
if (!end_of_sequence) {
|
||||
return errors::InvalidArgument("Dataset had more than one element.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BackgroundWorker background_worker_;
|
||||
};
|
||||
|
||||
@ -565,154 +529,113 @@ class ReduceDatasetOp : public AsyncOpKernel {
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
background_worker_.Schedule(std::bind(
|
||||
[this, ctx](std::function<void()>& done) {
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
|
||||
OpInputList inputs;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
|
||||
done);
|
||||
std::vector<Tensor> state(inputs.begin(), inputs.end());
|
||||
|
||||
std::unique_ptr<CapturedFunction> captured_func;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
CapturedFunction::Create(ctx, func_metadata_, "other_arguments",
|
||||
&captured_func),
|
||||
done);
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
auto function_handle_cache =
|
||||
absl::make_unique<FunctionHandleCache>(params.flr);
|
||||
params.function_handle_cache = function_handle_cache.get();
|
||||
ResourceMgr resource_mgr;
|
||||
params.resource_mgr = &resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to deregister the cancellation callback.
|
||||
done = std::bind(
|
||||
[](const std::function<void()>& done,
|
||||
const std::function<void()>& deregister_fn) {
|
||||
deregister_fn();
|
||||
done();
|
||||
},
|
||||
std::move(done), std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<InstantiatedCapturedFunction>
|
||||
instantiated_captured_func;
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
captured_func->Instantiate(
|
||||
&iter_ctx, &instantiated_captured_func),
|
||||
done);
|
||||
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to destroy the iterator before calling
|
||||
// the actual callback to avoid destruction races.
|
||||
IteratorBase* raw_iterator = iterator.release();
|
||||
done = std::bind(
|
||||
[raw_iterator](const std::function<void()>& done) {
|
||||
delete raw_iterator;
|
||||
done();
|
||||
},
|
||||
std::move(done));
|
||||
|
||||
// Iterate through the input dataset.
|
||||
Status status;
|
||||
while (true) {
|
||||
OP_REQUIRES_ASYNC(ctx, !ctx->cancellation_manager()->IsCancelled(),
|
||||
errors::Cancelled("Operation was cancelled"),
|
||||
done);
|
||||
std::vector<Tensor> next_input_element;
|
||||
bool end_of_input;
|
||||
status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
|
||||
&end_of_input);
|
||||
if (!status.ok() || end_of_input) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Run the reduce function to update the current state.
|
||||
std::vector<Tensor> args;
|
||||
args.reserve(state.size() + next_input_element.size());
|
||||
std::copy(state.begin(), state.end(), std::back_inserter(args));
|
||||
std::copy(next_input_element.begin(), next_input_element.end(),
|
||||
std::back_inserter(args));
|
||||
|
||||
std::vector<Tensor> reduce_func_output;
|
||||
status = instantiated_captured_func->Run(&iter_ctx, std::move(args),
|
||||
&reduce_func_output);
|
||||
if (!status.ok()) {
|
||||
break;
|
||||
}
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, reduce_func_output.size() == state.size(),
|
||||
errors::InvalidArgument(
|
||||
"The number of components of the initial state and the "
|
||||
"reduce "
|
||||
"function output does not match. (initial_state=",
|
||||
state.size(), ", output=", reduce_func_output.size(), ")."),
|
||||
done);
|
||||
std::swap(reduce_func_output, state);
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ctx->SetStatus(status);
|
||||
done();
|
||||
return;
|
||||
}
|
||||
|
||||
OP_REQUIRES_ASYNC(ctx, state.size() == output_types_.size(),
|
||||
errors::InvalidArgument(
|
||||
"The number of result elements does not match "
|
||||
"the size of output types: ",
|
||||
state.size(), " vs. ", output_types_.size()),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(ctx, state.size() == output_shapes_.size(),
|
||||
errors::InvalidArgument(
|
||||
"The number of result elements does not match "
|
||||
"the size of output shapes: ",
|
||||
state.size(), " vs. ", output_shapes_.size()),
|
||||
done);
|
||||
for (int i = 0; i < state.size(); ++i) {
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, state[i].dtype() == output_types_[i],
|
||||
errors::InvalidArgument(
|
||||
"The result does not match the expected type for "
|
||||
"component ",
|
||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
||||
". Actual: ", DataTypeString(state[i].dtype()), "."),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
|
||||
errors::InvalidArgument(
|
||||
"The result does not match the expected shape for "
|
||||
"component ",
|
||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
||||
". Actual: ", state[i].shape().DebugString(), "."),
|
||||
done);
|
||||
ctx->set_output(i, state[i]);
|
||||
}
|
||||
done();
|
||||
},
|
||||
std::move(done)));
|
||||
background_worker_.Schedule([this, ctx, done = std::move(done)]() {
|
||||
OP_REQUIRES_OK_ASYNC(ctx, DoCompute(ctx), done);
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
Status DoCompute(OpKernelContext* ctx) {
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
OpInputList inputs;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("initial_state", &inputs));
|
||||
std::vector<Tensor> state(inputs.begin(), inputs.end());
|
||||
|
||||
std::unique_ptr<CapturedFunction> captured_func;
|
||||
TF_RETURN_IF_ERROR(CapturedFunction::Create(
|
||||
ctx, func_metadata_, "other_arguments", &captured_func));
|
||||
|
||||
IteratorContext::Params params(ctx);
|
||||
auto function_handle_cache =
|
||||
absl::make_unique<FunctionHandleCache>(params.flr);
|
||||
params.function_handle_cache = function_handle_cache.get();
|
||||
ResourceMgr resource_mgr;
|
||||
params.resource_mgr = &resource_mgr;
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func;
|
||||
TF_RETURN_IF_ERROR(
|
||||
captured_func->Instantiate(&iter_ctx, &instantiated_captured_func));
|
||||
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator));
|
||||
|
||||
// Iterate through the input dataset.
|
||||
while (true) {
|
||||
if (ctx->cancellation_manager()->IsCancelled()) {
|
||||
return errors::Cancelled("Operation was cancelled");
|
||||
}
|
||||
std::vector<Tensor> next_input_element;
|
||||
bool end_of_input;
|
||||
TF_RETURN_IF_ERROR(
|
||||
iterator->GetNext(&iter_ctx, &next_input_element, &end_of_input));
|
||||
if (end_of_input) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Run the reduce function to update the current state.
|
||||
std::vector<Tensor> args;
|
||||
args.reserve(state.size() + next_input_element.size());
|
||||
std::copy(state.begin(), state.end(), std::back_inserter(args));
|
||||
std::copy(next_input_element.begin(), next_input_element.end(),
|
||||
std::back_inserter(args));
|
||||
|
||||
std::vector<Tensor> reduce_func_output;
|
||||
TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
|
||||
&iter_ctx, std::move(args), &reduce_func_output));
|
||||
if (reduce_func_output.size() != state.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"The number of components of the initial state and the "
|
||||
"reduce "
|
||||
"function output does not match. (initial_state=",
|
||||
state.size(), ", output=", reduce_func_output.size(), ").");
|
||||
}
|
||||
std::swap(reduce_func_output, state);
|
||||
}
|
||||
|
||||
if (state.size() != output_types_.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"The number of result elements does not match "
|
||||
"the size of output types: ",
|
||||
state.size(), " vs. ", output_types_.size());
|
||||
}
|
||||
if (state.size() != output_shapes_.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"The number of result elements does not match "
|
||||
"the size of output shapes: ",
|
||||
state.size(), " vs. ", output_shapes_.size());
|
||||
}
|
||||
for (size_t i = 0; i < state.size(); ++i) {
|
||||
if (state[i].dtype() != output_types_[i]) {
|
||||
return errors::InvalidArgument(
|
||||
"The result does not match the expected type for "
|
||||
"component ",
|
||||
i, ". Expected: ", DataTypeString(output_types_[i]),
|
||||
". Actual: ", DataTypeString(state[i].dtype()), ".");
|
||||
}
|
||||
if (!output_shapes_[i].IsCompatibleWith(state[i].shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"The result does not match the expected shape for "
|
||||
"component ",
|
||||
i, ". Expected: ", output_shapes_[i].DebugString(),
|
||||
". Actual: ", state[i].shape().DebugString(), ".");
|
||||
}
|
||||
ctx->set_output(i, state[i]);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<FunctionMetadata> func_metadata_ = nullptr;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
|
Loading…
x
Reference in New Issue
Block a user