Merge pull request from dnguyen28061:forward_input_or_allocate_output

PiperOrigin-RevId: 327713174
Change-Id: I334d9a34790d839be1b9dd3bb422dcf0a780d1c4
This commit is contained in:
TensorFlower Gardener 2020-08-20 16:08:31 -07:00
commit 66603e9220
3 changed files with 109 additions and 0 deletions

View File

@ -280,6 +280,36 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
return tf_tensor;
}
TF_Tensor* TF_ForwardInputOrAllocateOutput(
TF_OpKernelContext* context, int* candidate_input_indices,
int num_candidate_input_indices, int output_index, int64_t* output_dims,
int output_num_dims, int* forwarded_input, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<int> input_indices_array(
candidate_input_indices, num_candidate_input_indices);
tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
tensorflow::Tensor* output_tensor_pointer;
tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
input_indices_array, output_index,
tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
forwarded_input);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return tf_tensor_output;
}
TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
int64_t* dims, int num_dims,
TF_AllocatorAttributes* attributes,

View File

@ -200,6 +200,17 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
int64_t* dims, int num_dims,
size_t len, TF_Status* status);
// Tries to forward one of the inputs given in input_indices to
// output[output_index]. If none of the given inputs can be forwarded, calls
// allocate_output() to allocate a new output buffer. The index of the
// forwarded input will be assign to output argument forwarded_input (if it's
// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
// -1.
TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
TF_OpKernelContext* context, int* candidate_input_indices,
int num_candidate_input_indices, int output_index, int64_t* output_dims,
int output_num_dims, int* forwarded_input, TF_Status* status);
// Allocates a temporary Tensor of the specified type and shape. The
// Tensor must not be used after kernel construction is
// complete.

View File

@ -565,6 +565,74 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
output->DebugString(100));
}
TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
const char* node_name = "TestForwardInputOrAllocateOutputKernel";
const char* op_name = "BazOp";
const char* device_name = "FakeDeviceName";
REGISTER_OP(op_name)
.Input("input1: float")
.Input("input2: float")
.Output("output1: float")
.Attr("SomeDataTypeAttr: type");
// A kernel whose Compute function that forwards a scalar input to output
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
int candidate_input_indices[1] = {0};
int forwarded_input;
int64_t output_dims[1] = {};
TF_Tensor* output = TF_ForwardInputOrAllocateOutput(
/*context=*/ctx, candidate_input_indices,
/*num_candidate_input_indices=*/1,
/*output_index=*/0, output_dims, /*output_num_dims=*/0,
&forwarded_input, /*status=*/s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(forwarded_input, 0);
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(0, TF_NumDims(output));
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
my_compute_func, nullptr);
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
{
OpKernelContext::Params p;
DummyDevice dummy_device(nullptr);
p.device = &dummy_device;
AllocatorAttributes alloc_attrs;
p.output_attr_array = &alloc_attrs;
Tensor t(123.0f);
gtl::InlinedVector<TensorValue, 4> inputs;
// GetFakeKernel requires a NodeDef with two inputs
inputs.emplace_back(&t);
inputs.emplace_back();
p.inputs = &inputs;
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, node_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
p.op_kernel = kernel.get();
OpKernelContext ctx(&p);
kernel->Compute(&ctx);
ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
}
}
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype) {
EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));