321 lines
12 KiB
C++
321 lines
12 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/ops.h"
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "tensorflow/c/c_api.h"
|
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
|
#include "tensorflow/core/framework/fake_input.h"
|
|
#include "tensorflow/core/framework/op_def.pb.h"
|
|
#include "tensorflow/core/framework/op_def_builder.h"
|
|
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
|
#include "tensorflow/core/framework/types.pb.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
namespace tensorflow {
|
|
namespace {
|
|
|
|
TEST(OpsTest, TestBasicOpRegistration) {
|
|
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
|
|
TF_OpDefinitionBuilderAddAttr(builder, "attr1: string");
|
|
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
|
TF_OpDefinitionBuilderAddInput(builder, "input2: uint16");
|
|
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint32");
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_RegisterOpDefinition(builder, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
TF_Buffer* op_list_buffer = TF_GetAllOpList();
|
|
::tensorflow::OpList op_list;
|
|
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
|
bool found = false;
|
|
for (const auto& op : op_list.op()) {
|
|
if (op.name() == "SomeOp") {
|
|
ASSERT_EQ(2, op.input_arg_size());
|
|
ASSERT_EQ("input1", op.input_arg(0).name());
|
|
ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type());
|
|
ASSERT_EQ(1, op.attr_size());
|
|
ASSERT_EQ("string", op.attr(0).type());
|
|
found = true;
|
|
}
|
|
}
|
|
EXPECT_TRUE(found);
|
|
TF_DeleteStatus(status);
|
|
TF_DeleteBuffer(op_list_buffer);
|
|
}
|
|
|
|
void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
|
|
TF_ShapeHandle* handle = TF_NewShapeHandle();
|
|
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
|
TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status);
|
|
TF_DeleteShapeHandle(handle);
|
|
}
|
|
|
|
TEST(OpsTest, TestShapeInference_IdentityFunction) {
|
|
ShapeInferenceTestOp op("SomeTestOp");
|
|
|
|
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
|
|
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
|
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
|
|
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_RegisterOpDefinition(builder, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
|
|
TF_ASSERT_OK(
|
|
shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0"));
|
|
TF_DeleteStatus(status);
|
|
}
|
|
|
|
TEST(OpsTest, TestShapeInference_UnknownShape) {
|
|
ShapeInferenceTestOp op("UnknownShapeOp");
|
|
|
|
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("UnknownShapeOp");
|
|
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
|
TF_OpDefinitionBuilderAddInput(builder, "input2: uint32");
|
|
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
|
|
TF_OpDefinitionBuilderAddOutput(builder, "output2: uint8");
|
|
TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
|
builder, &TF_ShapeInferenceContextSetUnknownShape);
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_RegisterOpDefinition(builder, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
|
|
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
|
|
op, "[1,2];[3,4]", "?;?"));
|
|
TF_DeleteStatus(status);
|
|
}
|
|
|
|
// Creates an output whose shape is a vector of length
|
|
// TF_ShapeInferenceContextRank.
|
|
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
|
|
TF_ShapeHandle* handle = TF_NewShapeHandle();
|
|
TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
|
TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize(
|
|
ctx, TF_ShapeInferenceContextRank(ctx, handle));
|
|
TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status);
|
|
TF_DeleteShapeHandle(handle);
|
|
TF_DeleteShapeHandle(new_shape);
|
|
}
|
|
|
|
TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
|
ShapeInferenceTestOp op("VectorizeTestOp");
|
|
|
|
TF_OpDefinitionBuilder* builder =
|
|
TF_NewOpDefinitionBuilder("VectorizeTestOp");
|
|
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
|
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
|
|
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_RegisterOpDefinition(builder, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
|
|
|
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
|
|
op, "[4,5,9]", "[3]"));
|
|
TF_DeleteStatus(status);
|
|
}
|
|
|
|
TEST(OpsTest, AttributeAccessors) {
|
|
TF_OpDefinitionBuilder* builder =
|
|
TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
|
|
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
|
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
|
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
|
TF_OpDefinitionBuilderSetIsAggregate(builder, true);
|
|
TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
|
|
std::string deprecation_msg = "use something else instead";
|
|
TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str());
|
|
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_RegisterOpDefinition(builder, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
|
|
|
TF_Buffer* op_list_buffer = TF_GetAllOpList();
|
|
::tensorflow::OpList op_list;
|
|
op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
|
|
bool found = false;
|
|
for (const auto& op : op_list.op()) {
|
|
if (op.name() == "AttributeAccessorsOp") {
|
|
ASSERT_TRUE(op.is_commutative());
|
|
ASSERT_TRUE(op.is_aggregate());
|
|
ASSERT_TRUE(op.allows_uninitialized_input());
|
|
ASSERT_EQ(4, op.deprecation().version());
|
|
ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
|
|
ASSERT_EQ(2, op.attr_size());
|
|
ASSERT_EQ("int", op.attr(0).type());
|
|
ASSERT_EQ(2, op.attr(0).minimum());
|
|
ASSERT_EQ("string", op.attr(1).type());
|
|
ASSERT_EQ("my string", op.attr(1).default_value().s());
|
|
found = true;
|
|
}
|
|
}
|
|
ASSERT_TRUE(found);
|
|
TF_DeleteStatus(status);
|
|
TF_DeleteBuffer(op_list_buffer);
|
|
}
|
|
|
|
#define C_CTX(x) reinterpret_cast<TF_ShapeInferenceContext*>(x)
|
|
#define C_SHP(x) reinterpret_cast<TF_ShapeHandle*>(x)
|
|
|
|
static OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
|
OpRegistrationData op_reg_data;
|
|
OpDefBuilder b("dummy");
|
|
for (int i = 0; i < num_inputs; ++i) {
|
|
b.Input(strings::StrCat("i", i, ": float"));
|
|
}
|
|
for (int i = 0; i < num_outputs; ++i) {
|
|
b.Output(strings::StrCat("o", i, ": float"));
|
|
}
|
|
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
|
|
return op_reg_data.op_def;
|
|
}
|
|
|
|
// Tests for shape inference
|
|
|
|
PartialTensorShape S(std::initializer_list<int64> dims) {
|
|
return PartialTensorShape(dims);
|
|
}
|
|
|
|
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
|
|
|
TEST(OpsTest, ShapeInferenceWithRank) {
|
|
NodeDef def;
|
|
shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
|
|
{S({10, 20, 30})}, {}, {}, {});
|
|
|
|
shape_inference::ShapeHandle in0 = c.input(0);
|
|
shape_inference::ShapeHandle s1;
|
|
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
|
|
status);
|
|
EXPECT_EQ("[10,20,30]", c.DebugString(s1));
|
|
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
|
|
|
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
|
|
status);
|
|
EXPECT_EQ("[10,20,30]", c.DebugString(s1));
|
|
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
|
|
|
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 6, C_SHP(&s1),
|
|
status);
|
|
ASSERT_NE(TF_OK, TF_GetCode(status));
|
|
|
|
TF_SetStatus(status, TF_OK, "");
|
|
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
|
|
status);
|
|
ASSERT_NE(TF_OK, TF_GetCode(status));
|
|
|
|
TF_SetStatus(status, TF_OK, "");
|
|
TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
|
|
status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
|
|
|
TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 4, C_SHP(&s1),
|
|
status);
|
|
ASSERT_NE(TF_OK, TF_GetCode(status));
|
|
|
|
TF_DeleteStatus(status);
|
|
}
|
|
|
|
TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
|
|
NodeDef def;
|
|
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 2),
|
|
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
|
|
|
shape_inference::ShapeHandle in0 = c.input(0);
|
|
shape_inference::ShapeHandle s1;
|
|
|
|
// WithRankAtMost and WithRankAtLeast on a shape with unknown dimensionality
|
|
// always succeed.
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
|
|
status);
|
|
EXPECT_EQ("?", c.DebugString(s1));
|
|
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
|
|
|
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
|
|
status);
|
|
EXPECT_EQ("?", c.DebugString(s1));
|
|
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
|
|
|
TF_DeleteStatus(status);
|
|
}
|
|
|
|
TEST(OpsTest, ShapeInferenceConcatenateShapes) {
|
|
NodeDef def;
|
|
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
|
|
{S({1, 2}), S({3, 4})}, {}, {}, {});
|
|
ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
|
|
shape_inference::ShapeHandle a = c.input(0);
|
|
shape_inference::ShapeHandle b = c.input(1);
|
|
TF_ShapeHandle* result = TF_NewShapeHandle();
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_ShapeInferenceContextConcatenateShapes(C_CTX(&c), C_SHP(&a), C_SHP(&b),
|
|
result, status);
|
|
EXPECT_EQ(
|
|
"[1,2,3,4]",
|
|
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(result)));
|
|
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
|
TF_DeleteShapeHandle(result);
|
|
TF_DeleteStatus(status);
|
|
}
|
|
|
|
TEST(OpsTest, DimensionHandleValueKnown) {
|
|
NodeDef def;
|
|
shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
|
|
{S({1, 2}), S({3, 4})}, {}, {}, {});
|
|
TF_ShapeHandle* handle =
|
|
TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
|
|
ASSERT_EQ(
|
|
"[43]",
|
|
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
|
|
ASSERT_EQ(1, TF_ShapeInferenceContextRankKnown(C_CTX(&c), handle));
|
|
ASSERT_EQ(1, TF_ShapeInferenceContextRank(C_CTX(&c), handle));
|
|
|
|
TF_DimensionHandle* dim_handle = TF_NewDimensionHandle();
|
|
TF_ShapeInferenceContextDim(C_CTX(&c), handle, 0, dim_handle);
|
|
ASSERT_EQ(1, TF_DimensionHandleValueKnown(dim_handle));
|
|
ASSERT_EQ(43, TF_DimensionHandleValue(dim_handle));
|
|
TF_DeleteShapeHandle(handle);
|
|
TF_DeleteDimensionHandle(dim_handle);
|
|
}
|
|
|
|
TEST(OpsTest, ShapeInferenceSubshape) {
|
|
NodeDef def;
|
|
shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
|
|
{S({10, 20, 30, 40, 50})}, {}, {}, {});
|
|
ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));
|
|
|
|
TF_ShapeHandle* handle = TF_NewShapeHandle();
|
|
TF_Status* status = TF_NewStatus();
|
|
TF_ShapeInferenceContextGetInput(C_CTX(&c), 0, handle, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
|
TF_ShapeInferenceContextSubshape(C_CTX(&c), handle, 1, -1, handle, status);
|
|
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
|
ASSERT_EQ(
|
|
"[20,30,40]",
|
|
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
|
|
TF_DeleteStatus(status);
|
|
TF_DeleteShapeHandle(handle);
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace tensorflow
|