[OpKernel] Implement OpKernelContext::output_required(int).

This long-pending feature lets kernels specialize for the case where some of their outputs are never consumed. As an example, this change uses `OpKernelContext::output_required()` in `SparseFillEmptyRowsOp`, whose `empty_row_indicator` and `reverse_index_map` outputs exist only to support backpropagation; by consulting `output_required()`, the kernel can skip those allocations in inference workloads.
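
For kernel authors, the intended usage pattern is the one the SparseFillEmptyRowsOp change below follows. A minimal sketch (the kernel class, op, and output index here are hypothetical, not part of this change):

    #include "tensorflow/core/framework/op_kernel.h"

    namespace tensorflow {

    // Hypothetical kernel whose output 1 is only needed for backprop.
    class MyOpKernel : public OpKernel {
     public:
      explicit MyOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}

      void Compute(OpKernelContext* ctx) override {
        // ... always produce output 0 ...

        // output_required() is a hint from the executor: if it returns false,
        // no downstream op consumes output 1, so the kernel may skip the
        // allocation entirely. A kernel is free to ignore the hint.
        if (ctx->output_required(1)) {
          Tensor* aux = nullptr;
          OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &aux));
          // ... fill *aux ...
        }
      }
    };

    }  // namespace tensorflow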

I considered implementing this as a Grappler rewrite pass, but the SparseFillEmptyRowsOp example convinced me that, at least in the codebase's present state, a runtime check was better. The rewrite-based alternative would have required four near-identical op registrations for SparseFillEmptyRows (with two optional outputs that may each be present or absent, there are 2^2 possible signatures), plus a mapping/registration scheme the rewriter could use to substitute the specialized implementations. By contrast, the runtime check is (i) cheap compared to op dispatch, and (ii) easy to implement gradually and locally.
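
To make the 2^2 blow-up concrete, each variant would have needed its own registration along the following lines (illustrative only; no such variant ops exist, and the signature is abridged from the real `SparseFillEmptyRows`):

    #include "tensorflow/core/framework/op.h"

    // One of the four hypothetical specializations: the variant that omits
    // reverse_index_map. The other subsets of {empty_row_indicator,
    // reverse_index_map} would each need an analogous registration, plus a
    // rewriter rule mapping the generic op onto the right variant.
    REGISTER_OP("SparseFillEmptyRowsNoReverseIndexMap")
        .Input("indices: int64")
        .Input("values: T")
        .Input("dense_shape: int64")
        .Input("default_value: T")
        .Output("output_indices: int64")
        .Output("output_values: T")
        .Output("empty_row_indicator: bool")
        .Attr("T: type");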

PiperOrigin-RevId: 297274989
Change-Id: I23b5207017921ba118bd4fc31dc54ff53fe4332d
Author: Derek Murray
Date: 2020-02-25 21:16:20 -08:00
Committed by: TensorFlower Gardener
Parent: 3755911efc
Commit: da13f69cfb
4 changed files with 53 additions and 23 deletions

@@ -714,10 +714,13 @@ Status ExecutorImpl::Initialize(const Graph& graph) {
           used_outputs[e->src_output()] = true;
         }
       }
+      int i = 0;
       for (bool used_output : used_outputs) {
         if (!used_output) {
           metrics::RecordUnusedOutput(n->type_string());
+          item->kernel->set_output_required(i, false);
         }
+        ++i;
       }
     }

@@ -2093,9 +2096,9 @@ Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   for (int i = 0; i < item.num_outputs; ++i) {
     const TensorValue val = ctx->release_output(i);
     if (val.tensor == nullptr) {
-      // Unless it's a Switch or a Recv, the node must produce a
-      // tensor value at i-th output.
-      if (!item.is_recv_or_switch) {
+      // Unless it's a Switch or a Recv, or the executor has marked the output
+      // as not required, the node must produce a tensor value at the i-th output.
+      if (!(item.is_recv_or_switch || !item.kernel->output_required(i))) {
         s.Update(errors::Internal("Missing ", i, "-th output from ",
                                   FormatNodeDefForError(item.kernel->def())));
       }

@@ -105,7 +105,8 @@ OpKernel::OpKernel(OpKernelConstruction* context, bool is_deferred)
       type_string_view_(props_->node_def.op()),
       graph_def_version_(context->graph_def_version()),
       is_deferred_(is_deferred),
-      cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
+      cost_estimate_(OpKernel::kInitialCostEstimateCycles),
+      outputs_required_(context->num_outputs(), true) {
   OP_REQUIRES_OK(context,
                  NameRangesForNode(props_->node_def, *props_->op_def,
                                    &input_name_map_, &output_name_map_));

@@ -133,7 +134,8 @@ OpKernel::OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
       type_string_view_(props_->node_def.op()),
       graph_def_version_(context->graph_def_version()),
       is_deferred_(is_deferred),
-      cost_estimate_(OpKernel::kInitialCostEstimateCycles) {
+      cost_estimate_(OpKernel::kInitialCostEstimateCycles),
+      outputs_required_(context->num_outputs(), true) {
   OP_REQUIRES_OK(context,
                  NameRangesForNode(props_->node_def, *props_->op_def,
                                    &input_name_map_, &output_name_map_));

@@ -156,6 +156,18 @@ class OpKernel {
   // Returns a pointer to the tensor stored inside constant ops.
   virtual const Tensor* const_tensor() const { return nullptr; }
 
+  // Returns true if this kernel must produce its ith output.
+  // REQUIRES: 0 <= i < num_outputs().
+  bool output_required(int i) const { return outputs_required_[i]; }
+
+  // Hints whether or not the ith output must be produced when running the
+  // kernel. By default, all outputs are required. The kernel implementation
+  // may ignore the hint.
+  // REQUIRES: 0 <= i < num_outputs().
+  void set_output_required(int i, bool is_required) {
+    outputs_required_[i] = is_required;
+  }
+
   // Updates the dynamic cost estimate, which is used to determine whether this
   // op is expensive. The new cost estimate is a weighted average of the old
   // cost estimate and the latest cost.
@@ -223,6 +235,7 @@ class OpKernel {
   const bool is_deferred_;
   bool expensive_;
   std::atomic_uint_fast64_t cost_estimate_;
+  std::vector<bool> outputs_required_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
 };

@@ -941,10 +954,8 @@ class OpKernelContext {
   // should call allocate_output(index, ...), set_output(index, ...),
   // set_output_ref(index, ...), or set the status to a non-ok value.
   // If it returns false, it may output, but is not required to do so.
-  // TODO(mrry): Convert this to return Status, and implement a string
-  // name version.
   bool output_required(int index) const {
-    return true;  // TODO(josh11b): implement
+    return op_kernel().output_required(index);
   }
 
   // Allocation of tensors during kernel execution inside the Compute

@@ -78,16 +78,23 @@ class SparseFillEmptyRowsOp : public OpKernel {
     const int64 N = indices_t.shape().dim_size(0);
     const int64 dense_rows = dense_shape(0);
 
-    Tensor* empty_row_indicator_t;
-    OP_REQUIRES_OK(context, context->allocate_output(kEmptyRowIndicatorOutput,
-                                                     TensorShape({dense_rows}),
-                                                     &empty_row_indicator_t));
-    auto empty_row_indicator = empty_row_indicator_t->vec<bool>();
-    Tensor* reverse_index_map_t;
-    OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput,
-                                                     TensorShape({N}),
-                                                     &reverse_index_map_t));
-    auto reverse_index_map = reverse_index_map_t->vec<int64>();
+    bool* empty_row_indicator = nullptr;
+    if (context->output_required(kEmptyRowIndicatorOutput)) {
+      Tensor* empty_row_indicator_t = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(kEmptyRowIndicatorOutput,
+                                              TensorShape({dense_rows}),
+                                              &empty_row_indicator_t));
+      empty_row_indicator = empty_row_indicator_t->vec<bool>().data();
+    }
+    int64* reverse_index_map = nullptr;
+    if (context->output_required(kReverseIndexMapOutput)) {
+      Tensor* reverse_index_map_t = nullptr;
+      OP_REQUIRES_OK(context, context->allocate_output(kReverseIndexMapOutput,
+                                                       TensorShape({N}),
+                                                       &reverse_index_map_t));
+      reverse_index_map = reverse_index_map_t->vec<int64>().data();
+    }
 
     int rank = indices_t.shape().dim_size(1);

@@ -122,8 +129,11 @@ class SparseFillEmptyRowsOp : public OpKernel {
     bool all_rows_full = true;
     for (int row = 0; row < dense_rows; ++row) {
       // csr_offset here describes the number of elements in this dense row
-      empty_row_indicator(row) = (csr_offset[row] == 0);
-      all_rows_full = all_rows_full & !empty_row_indicator(row);
+      bool row_empty = (csr_offset[row] == 0);
+      if (empty_row_indicator) {
+        empty_row_indicator[row] = row_empty;
+      }
+      all_rows_full = all_rows_full & !row_empty;
       // In filled version, each row has at least one element.
       csr_offset[row] = std::max(csr_offset[row], int64{1});
       // Update csr_offset to represent the number of elements up to and

@@ -140,8 +150,10 @@ class SparseFillEmptyRowsOp : public OpKernel {
     if (all_rows_full) {
       context->set_output(kOutputIndicesOutput, indices_t);
       context->set_output(kOutputValuesOutput, values_t);
-      for (int64 i = 0; i < N; ++i) {
-        reverse_index_map(i) = i;
+      if (reverse_index_map) {
+        for (int64 i = 0; i < N; ++i) {
+          reverse_index_map[i] = i;
+        }
       }
     } else {
       Tensor* output_indices_t;

@@ -169,7 +181,9 @@ class SparseFillEmptyRowsOp : public OpKernel {
         std::copy_n(&indices(i, 0), rank, &output_indices(output_i, 0));
         output_values(output_i) = values(i);
         // We'll need this reverse index map to backprop correctly.
-        reverse_index_map(i) = output_i;
+        if (reverse_index_map) {
+          reverse_index_map[i] = output_i;
+        }
       }
 
       // Fill in values for rows that are missing