Merge pull request from dnguyen28061:TF_GetName

PiperOrigin-RevId: 324061940
Change-Id: I9519ccad6fd4958e16e3d8abbf2b011f89d574a1
This commit is contained in:
TensorFlower Gardener 2020-07-30 13:05:48 -07:00
commit f999bb1785
4 changed files with 42 additions and 13 deletions

View File

@ -125,6 +125,14 @@ TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*);
TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer);
// --------------------------------------------------------------------------
// Used to return strings across the C API. The caller does not take ownership
// of the underlying data pointer and is not responsible for freeing it.
typedef struct TF_StringView {
const char* data;
size_t len;
} TF_StringView;
// --------------------------------------------------------------------------
// TF_SessionOptions holds options that can be passed during session creation.
typedef struct TF_SessionOptions TF_SessionOptions;

View File

@ -239,6 +239,14 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
TF_StringView string_view_of_name;
string_view_of_name.data = cc_ctx->def().name().data();
string_view_of_name.len = cc_ctx->def().name().length();
return string_view_of_name;
}
TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
@ -271,4 +279,4 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
return nullptr;
}
return tf_tensor;
}
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -184,6 +185,10 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
TF_Status* status);
// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
TF_OpKernelConstruction* ctx);
// Allocates Tensor for output at given index. Caller takes ownership of
// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor).
//

View File

@ -73,6 +73,12 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
EXPECT_EQ(TF_FLOAT, type);
TF_DeleteStatus(status);
// Exercise kernel NodeDef name read
TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx);
std::string node_name = "SomeNodeName";
std::string candidate_node_name =
std::string(name_string_view.data, name_string_view.len);
EXPECT_EQ(node_name, candidate_node_name);
return s;
}
@ -96,9 +102,11 @@ namespace tensorflow {
static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
const char* op_name,
const char* node_name,
Status* status) {
NodeDef def;
def.set_op(op_name);
def.set_name(node_name);
def.set_device(device_name);
def.add_input("input1");
def.add_input("input2");
@ -114,7 +122,7 @@ static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
// Tests registration of a single C kernel and checks that calls through the
// C/C++ boundary are being made.
TEST(TestKernel, TestRegisterKernelBuilder) {
const char* kernel_name = "SomeKernelName";
const char* node_name = "SomeNodeName";
const char* op_name = "FooOp";
const char* device_name = "FakeDeviceName1";
@ -129,7 +137,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
@ -144,7 +152,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
{
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, &status);
GetFakeKernel(device_name, op_name, node_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
kernel->Compute(nullptr);
@ -162,7 +170,7 @@ class DummyDevice : public DeviceBase {
};
TEST(TestKernel, TestInputAndOutputCount) {
const char* kernel_name = "InputOutputCounterKernel";
const char* node_name = "InputOutputCounterKernel";
const char* op_name = "BarOp";
const char* device_name = "FakeDeviceName2";
@ -212,7 +220,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
{
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
@ -233,7 +241,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
Status status;
std::unique_ptr<OpKernel> kernel =
GetFakeKernel(device_name, op_name, &status);
GetFakeKernel(device_name, op_name, node_name, &status);
TF_EXPECT_OK(status);
ASSERT_NE(nullptr, kernel.get());
@ -252,7 +260,7 @@ TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) {
}
TEST(TestKernel, TestTypeConstraint) {
const char* kernel_name = "SomeKernelName";
const char* node_name = "SomeNodeName";
const char* op_name = "TypeOp";
const char* device_name = "FakeDeviceName1";
@ -267,7 +275,7 @@ TEST(TestKernel, TestTypeConstraint) {
TF_Status* status = TF_NewStatus();
TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_RegisterKernelBuilder(kernel_name, builder, status);
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
@ -296,7 +304,7 @@ TEST(TestKernel, TestTypeConstraint) {
}
TEST(TestKernel, TestHostMemory) {
const char* kernel_name = "SomeKernelName";
const char* node_name = "SomeNodeName";
const char* op_name = "HostMemoryOp";
const char* device_name = "FakeDeviceName1";
@ -311,7 +319,7 @@ TEST(TestKernel, TestHostMemory) {
TF_KernelBuilder_HostMemory(builder, "input2");
TF_KernelBuilder_HostMemory(builder, "output1");
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
@ -335,12 +343,12 @@ TEST(TestKernel, TestHostMemory) {
class DeviceKernelOpTest : public OpsTestBase {
protected:
void SetupOp(const char* op_name, const char* kernel_name,
void SetupOp(const char* op_name, const char* node_name,
void (*compute_func)(void*, TF_OpKernelContext*)) {
TF_KernelBuilder* builder = TF_NewKernelBuilder(
op_name, device_name_, nullptr, compute_func, nullptr);
TF_Status* status = TF_NewStatus();
TF_RegisterKernelBuilder(kernel_name, builder, status);
TF_RegisterKernelBuilder(node_name, builder, status);
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);