clean up only

This commit is contained in:
Daniel Nguyen 2020-08-11 18:18:25 +00:00
parent 0a79e71110
commit aa88605eae
3 changed files with 15 additions and 11 deletions

View File

@ -282,13 +282,16 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
} }
TF_Tensor* TF_ForwardInputOrAllocateOutput(TF_OpKernelContext* context, TF_Tensor* TF_ForwardInputOrAllocateOutput(TF_OpKernelContext* context,
int* candidate_input_indices, int num_input_indices, int output_index, int* candidate_input_indices, int num_candidate_input_indices,
int64_t* output_dims, int output_num_dims, int* forwarded_input, int output_index, int64_t* output_dims, int output_num_dims,
TF_Status* status) { int* forwarded_input, TF_Status* status) {
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<int> input_indices_array(candidate_input_indices, tensorflow::gtl::ArraySlice<int> input_indices_array(candidate_input_indices,
num_input_indices); num_candidate_input_indices);
tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray( tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims); reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
tensorflow::Tensor* output_tensor_pointer; tensorflow::Tensor* output_tensor_pointer;

View File

@ -205,10 +205,9 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
// forwarded input will be assign to output argument forwarded_input (if it's // forwarded input will be assign to output argument forwarded_input (if it's
// not nullptr). If no inputs are forwarded, forwarded_input will be assigned // not nullptr). If no inputs are forwarded, forwarded_input will be assigned
// -1. // -1.
TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput( TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
TF_OpKernelContext* context, int* candidate_input_indices, TF_OpKernelContext* context, int* candidate_input_indices,
int num_input_indices, int output_index, int64_t* output_dims, int num_candidate_input_indices, int output_index, int64_t* output_dims,
int output_num_dims, int* forwarded_input, TF_Status* status); int output_num_dims, int* forwarded_input, TF_Status* status);
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -486,14 +486,16 @@ TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
.Output("output1: float") .Output("output1: float")
.Attr("SomeDataTypeAttr: type");; .Attr("SomeDataTypeAttr: type");;
// A kernel whose Compute function that forwards one input to output // A kernel whose Compute function that forwards a scalar input to output
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus(); TF_Status* s = TF_NewStatus();
int candidate_input_indices[1] = {0}; int candidate_input_indices[1] = {0};
int forwarded_input; int forwarded_input;
int64_t output_dims[1] = {}; int64_t output_dims[1] = {};
TF_Tensor* output = TF_ForwardInputOrAllocateOutput(ctx, TF_Tensor* output = TF_ForwardInputOrAllocateOutput(/*context=*/ctx,
candidate_input_indices, 1, 0, output_dims, 0, &forwarded_input, s); candidate_input_indices, /*num_candidate_input_indices=*/1,
/*output_index=*/0, output_dims, /*output_num_dims=*/0,
&forwarded_input, /*status=*/s);
EXPECT_EQ(TF_OK, TF_GetCode(s)); EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(forwarded_input, 0); EXPECT_EQ(forwarded_input, 0);
EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
@ -518,7 +520,7 @@ TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) {
AllocatorAttributes alloc_attrs; AllocatorAttributes alloc_attrs;
p.output_attr_array = &alloc_attrs; p.output_attr_array = &alloc_attrs;
Tensor t(static_cast<float>(123)); Tensor t(123.0f);
gtl::InlinedVector<TensorValue, 4> inputs; gtl::InlinedVector<TensorValue, 4> inputs;
// GetFakeKernel requires a NodeDef with two inputs // GetFakeKernel requires a NodeDef with two inputs