finished implementation and passes tests
This commit is contained in:
parent
c5ef52c5f0
commit
0a79e71110
@ -281,26 +281,30 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
|
|||||||
return tf_tensor;
|
return tf_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TF_ForwardInputOrAllocateOutput(TF_OpKernelContext* context,
|
TF_Tensor* TF_ForwardInputOrAllocateOutput(TF_OpKernelContext* context,
|
||||||
int* candidate_input_indices, int num_input_indices, int output_index,
|
int* candidate_input_indices, int num_input_indices, int output_index,
|
||||||
int64_t* output_dims, int output_num_dims, TF_Tensor** output,
|
int64_t* output_dims, int output_num_dims, int* forwarded_input,
|
||||||
int* forwarded_input, TF_Status* status) {
|
TF_Status* status) {
|
||||||
TF_SetStatus(status, TF_OK, "");
|
TF_SetStatus(status, TF_OK, "");
|
||||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
||||||
tensorflow::gtl::ArraySlice<int> input_indices_array(candidate_input_indices,
|
tensorflow::gtl::ArraySlice<int> input_indices_array(candidate_input_indices,
|
||||||
num_input_indices);
|
num_input_indices);
|
||||||
tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
|
tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
|
||||||
reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
|
reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
|
||||||
tensorflow::Tensor output_tensor;
|
tensorflow::Tensor* output_tensor_pointer;
|
||||||
tensorflow::Status s = TF_TensorToTensor(*output, &output_tensor);
|
tensorflow::Status s = cc_ctx->
|
||||||
if (!s.ok()) {
|
|
||||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
tensorflow::Tensor* output_tensor_pointer = &output_tensor;
|
|
||||||
tensorflow::Status forward_input_status = cc_ctx->
|
|
||||||
forward_input_or_allocate_output(input_indices_array, output_index,
|
forward_input_or_allocate_output(input_indices_array, output_index,
|
||||||
tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
|
tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
|
||||||
forwarded_input);
|
forwarded_input);
|
||||||
|
if (!s.ok()) {
|
||||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
TF_Tensor* tf_tensor_output = TF_TensorFromTensor(
|
||||||
|
*output_tensor_pointer, &s);
|
||||||
|
if (!s.ok()) {
|
||||||
|
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return tf_tensor_output;
|
||||||
}
|
}
|
||||||
|
@ -199,10 +199,17 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
|
|||||||
int64_t* dims, int num_dims,
|
int64_t* dims, int num_dims,
|
||||||
size_t len, TF_Status* status);
|
size_t len, TF_Status* status);
|
||||||
|
|
||||||
TF_CAPI_EXPORT void TF_ForwardInputOrAllocateOutput(TF_OpKernelContext* context,
|
// Tries to forward one of the inputs given in candidate_input_indices to
|
||||||
int* candidate_input_indices, int num_input_indices, int output_index,
|
// output[output_index]. If none of the given inputs can be forwarded, calls
|
||||||
int64_t* output_dims, int output_num_dims, TF_Tensor** output,
|
// allocate_output() to allocate a new output buffer. The index of the
|
||||||
int* forwarded_input, TF_Status* status);
|
// forwarded input will be assigned to output argument forwarded_input (if it's
|
||||||
|
// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
|
||||||
|
// -1.
|
||||||
|
|
||||||
|
TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
|
||||||
|
TF_OpKernelContext* context, int* candidate_input_indices,
|
||||||
|
int num_input_indices, int output_index, int64_t* output_dims,
|
||||||
|
int output_num_dims, int* forwarded_input, TF_Status* status);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
|
@ -474,4 +474,68 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
|
|||||||
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
|
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
|
||||||
output->DebugString(100));
|
output->DebugString(100));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
|
||||||
|
const char* node_name = "TestForwardInputOrAllocateOutputKernel";
|
||||||
|
const char* op_name = "BazOp";
|
||||||
|
const char* device_name = "FakeDeviceName";
|
||||||
|
|
||||||
|
REGISTER_OP(op_name)
|
||||||
|
.Input("input1: float")
|
||||||
|
.Input("input2: float")
|
||||||
|
.Output("output1: float")
|
||||||
|
.Attr("SomeDataTypeAttr: type");;
|
||||||
|
|
||||||
|
// A kernel whose Compute function forwards one input to the output
|
||||||
|
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
int candidate_input_indices[1] = {0};
|
||||||
|
int forwarded_input;
|
||||||
|
int64_t output_dims[1] = {};
|
||||||
|
TF_Tensor* output = TF_ForwardInputOrAllocateOutput(ctx,
|
||||||
|
candidate_input_indices, 1, 0, output_dims, 0, &forwarded_input, s);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(s));
|
||||||
|
EXPECT_EQ(forwarded_input, 0);
|
||||||
|
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
|
||||||
|
EXPECT_EQ(0, TF_NumDims(output));
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
};
|
||||||
|
|
||||||
|
TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr,
|
||||||
|
my_compute_func, nullptr);
|
||||||
|
|
||||||
|
{
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
OpKernelContext::Params p;
|
||||||
|
DummyDevice dummy_device(nullptr);
|
||||||
|
p.device = &dummy_device;
|
||||||
|
AllocatorAttributes alloc_attrs;
|
||||||
|
p.output_attr_array = &alloc_attrs;
|
||||||
|
|
||||||
|
Tensor t(static_cast<float>(123));
|
||||||
|
|
||||||
|
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||||
|
// GetFakeKernel requires a NodeDef with two inputs
|
||||||
|
inputs.emplace_back(&t);
|
||||||
|
inputs.emplace_back();
|
||||||
|
p.inputs = &inputs;
|
||||||
|
|
||||||
|
Status status;
|
||||||
|
std::unique_ptr<OpKernel> kernel =
|
||||||
|
GetFakeKernel(device_name, op_name, node_name, &status);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
ASSERT_NE(nullptr, kernel.get());
|
||||||
|
|
||||||
|
p.op_kernel = kernel.get();
|
||||||
|
OpKernelContext ctx(&p);
|
||||||
|
kernel->Compute(&ctx);
|
||||||
|
ASSERT_EQ(123, ctx.mutable_output(0)->scalar<float>()());
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
x
Reference in New Issue
Block a user