Merge pull request #41609 from dnguyen28061:TF_GetName
PiperOrigin-RevId: 324061940 Change-Id: I9519ccad6fd4958e16e3d8abbf2b011f89d574a1
This commit is contained in:
commit
f999bb1785
tensorflow/c
@ -125,6 +125,14 @@ TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*);
|
||||
|
||||
TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Used to return strings across the C API. The caller does not take ownership
|
||||
// of the underlying data pointer and is not responsible for freeing it.
|
||||
typedef struct TF_StringView {
|
||||
const char* data;
|
||||
size_t len;
|
||||
} TF_StringView;
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// TF_SessionOptions holds options that can be passed during session creation.
|
||||
typedef struct TF_SessionOptions TF_SessionOptions;
|
||||
|
@ -239,6 +239,14 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
|
||||
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
|
||||
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
|
||||
|
||||
TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
|
||||
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
|
||||
TF_StringView string_view_of_name;
|
||||
string_view_of_name.data = cc_ctx->def().name().data();
|
||||
string_view_of_name.len = cc_ctx->def().name().length();
|
||||
return string_view_of_name;
|
||||
}
|
||||
|
||||
TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) {
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx);
|
||||
return static_cast<TF_DataType>(cc_ctx->expected_output_dtype(i));
|
||||
@ -271,4 +279,4 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
|
||||
return nullptr;
|
||||
}
|
||||
return tf_tensor;
|
||||
}
|
||||
}
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/tf_datatype.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
@ -184,6 +185,10 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
|
||||
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns the unique operation name for this OpKernel.
|
||||
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
|
||||
TF_OpKernelConstruction* ctx);
|
||||
|
||||
// Allocates Tensor for output at given index. Caller takes ownership of
|
||||
// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor).
|
||||
//
|
||||
|
@ -73,6 +73,12 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) {
|
||||
EXPECT_EQ(TF_FLOAT, type);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
// Exercise kernel NodeDef name read
|
||||
TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx);
|
||||
std::string node_name = "SomeNodeName";
|
||||
std::string candidate_node_name =
|
||||
std::string(name_string_view.data, name_string_view.len);
|
||||
EXPECT_EQ(node_name, candidate_node_name);
|
||||
return s;
|
||||
}
|
||||
|
||||
@ -96,9 +102,11 @@ namespace tensorflow {
|
||||
|
||||
static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
|
||||
const char* op_name,
|
||||
const char* node_name,
|
||||
Status* status) {
|
||||
NodeDef def;
|
||||
def.set_op(op_name);
|
||||
def.set_name(node_name);
|
||||
def.set_device(device_name);
|
||||
def.add_input("input1");
|
||||
def.add_input("input2");
|
||||
@ -114,7 +122,7 @@ static std::unique_ptr<OpKernel> GetFakeKernel(const char* device_name,
|
||||
// Tests registration of a single C kernel and checks that calls through the
|
||||
// C/C++ boundary are being made.
|
||||
TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
const char* kernel_name = "SomeKernelName";
|
||||
const char* node_name = "SomeNodeName";
|
||||
const char* op_name = "FooOp";
|
||||
const char* device_name = "FakeDeviceName1";
|
||||
|
||||
@ -129,7 +137,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
|
||||
{
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder(kernel_name, builder, status);
|
||||
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
@ -144,7 +152,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) {
|
||||
{
|
||||
Status status;
|
||||
std::unique_ptr<OpKernel> kernel =
|
||||
GetFakeKernel(device_name, op_name, &status);
|
||||
GetFakeKernel(device_name, op_name, node_name, &status);
|
||||
TF_EXPECT_OK(status);
|
||||
ASSERT_NE(nullptr, kernel.get());
|
||||
kernel->Compute(nullptr);
|
||||
@ -162,7 +170,7 @@ class DummyDevice : public DeviceBase {
|
||||
};
|
||||
|
||||
TEST(TestKernel, TestInputAndOutputCount) {
|
||||
const char* kernel_name = "InputOutputCounterKernel";
|
||||
const char* node_name = "InputOutputCounterKernel";
|
||||
const char* op_name = "BarOp";
|
||||
const char* device_name = "FakeDeviceName2";
|
||||
|
||||
@ -212,7 +220,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
|
||||
|
||||
{
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder(kernel_name, builder, status);
|
||||
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
@ -233,7 +241,7 @@ TEST(TestKernel, TestInputAndOutputCount) {
|
||||
|
||||
Status status;
|
||||
std::unique_ptr<OpKernel> kernel =
|
||||
GetFakeKernel(device_name, op_name, &status);
|
||||
GetFakeKernel(device_name, op_name, node_name, &status);
|
||||
TF_EXPECT_OK(status);
|
||||
ASSERT_NE(nullptr, kernel.get());
|
||||
|
||||
@ -252,7 +260,7 @@ TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) {
|
||||
}
|
||||
|
||||
TEST(TestKernel, TestTypeConstraint) {
|
||||
const char* kernel_name = "SomeKernelName";
|
||||
const char* node_name = "SomeNodeName";
|
||||
const char* op_name = "TypeOp";
|
||||
const char* device_name = "FakeDeviceName1";
|
||||
|
||||
@ -267,7 +275,7 @@ TEST(TestKernel, TestTypeConstraint) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_RegisterKernelBuilder(kernel_name, builder, status);
|
||||
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
|
||||
@ -296,7 +304,7 @@ TEST(TestKernel, TestTypeConstraint) {
|
||||
}
|
||||
|
||||
TEST(TestKernel, TestHostMemory) {
|
||||
const char* kernel_name = "SomeKernelName";
|
||||
const char* node_name = "SomeNodeName";
|
||||
const char* op_name = "HostMemoryOp";
|
||||
const char* device_name = "FakeDeviceName1";
|
||||
|
||||
@ -311,7 +319,7 @@ TEST(TestKernel, TestHostMemory) {
|
||||
TF_KernelBuilder_HostMemory(builder, "input2");
|
||||
TF_KernelBuilder_HostMemory(builder, "output1");
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder(kernel_name, builder, status);
|
||||
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status);
|
||||
@ -335,12 +343,12 @@ TEST(TestKernel, TestHostMemory) {
|
||||
|
||||
class DeviceKernelOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void SetupOp(const char* op_name, const char* kernel_name,
|
||||
void SetupOp(const char* op_name, const char* node_name,
|
||||
void (*compute_func)(void*, TF_OpKernelContext*)) {
|
||||
TF_KernelBuilder* builder = TF_NewKernelBuilder(
|
||||
op_name, device_name_, nullptr, compute_func, nullptr);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterKernelBuilder(kernel_name, builder, status);
|
||||
TF_RegisterKernelBuilder(node_name, builder, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user