Automated rollback of commit fa636543de3159b35e304f7db670d13df8af9878

PiperOrigin-RevId: 250567146
This commit is contained in:
James Ring 2019-05-29 14:28:23 -07:00 committed by TensorFlower Gardener
parent 0facf5a30f
commit b25e77ca06
4 changed files with 201 additions and 272 deletions

View File

@ -221,6 +221,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = [
":c_api_internal",
":c_api_no_xla",
"//tensorflow/core:lib",
],
)
@ -527,13 +528,10 @@ tf_cc_test(
deps = [
":c_api",
":ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/strings",
],
)

View File

@ -23,7 +23,6 @@ limitations under the License.
using ::tensorflow::DataType;
using ::tensorflow::OpDef;
using ::tensorflow::OpDefBuilder;
using ::tensorflow::OpDeprecation;
using ::tensorflow::OpShapeInferenceFn;
using ::tensorflow::Set_TF_Status_from_Status;
@ -32,53 +31,111 @@ using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;
typedef struct TF_OpDefinitionBuilder {
// The op definition proto representing the op.
tensorflow::OpDef op_def;
// The shape inference function, or nullptr if none is provided for this op.
OpShapeInferenceFn shape_inference_func;
} TF_OpDefinitionBuilder;
TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
auto* result = new OpDefBuilder(op_name);
return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
auto* result = new TF_OpDefinitionBuilder;
result->op_def.set_name(op_name);
return result;
}
void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
delete reinterpret_cast<OpDefBuilder*>(builder);
delete builder;
}
static void PopulateArg(OpDef::ArgDef* arg, const char* name,
TF_DataType type) {
arg->set_name(name);
arg->set_type(static_cast<DataType>(type));
}
void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
const char* input_spec) {
reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
const char* name, TF_DataType type) {
PopulateArg(builder->op_def.add_input_arg(), name, type);
}
void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
const char* output_spec) {
reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
const char* name, TF_DataType type) {
PopulateArg(builder->op_def.add_output_arg(), name, type);
}
#define DEFINE_BUILDER_BOOL_SETTER(func_name) \
#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
bool arg_name) { \
reinterpret_cast<OpDefBuilder*>(builder)->func_name(); \
builder->op_def.builder_setter_name(arg_name); \
}
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate)
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful)
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput,
set_allows_uninitialized_input,
allows_unintialized_input)
void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
const char* attr_spec) {
reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder,
const char* name, const char* type_name) {
OpDef::AttrDef* attr = builder->op_def.add_attr();
attr->set_name(name);
attr->set_type(type_name);
return attr;
}
#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \
void TF_OpDefinitionBuilderAdd##attr_type##Attr( \
TF_OpDefinitionBuilder* builder, const char* name) { \
AddAttribute(builder, name, type_name); \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \
TF_OpDefinitionBuilder* builder, const char* name, \
field_c_type field_name) { \
OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \
attr->mutable_default_value()->set_##field_name(field_name); \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
TF_OpDefinitionBuilder* builder, const char* name, \
field_c_type field_name[], size_t n) { \
OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \
for (int _i = 0; _i < n; ++_i) { \
attr->mutable_default_value()->mutable_list()->add_##field_name( \
field_name[_i]); \
} \
} \
\
void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \
TF_OpDefinitionBuilder* builder, const char* name) { \
TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
builder, name, NULL, 0); \
}
DEFINE_ATTR_SETTER(String, "string", const char*, s)
DEFINE_ATTR_SETTER(Int, "int", int64_t, i)
DEFINE_ATTR_SETTER(Float, "float", float, f)
DEFINE_ATTR_SETTER(Bool, "bool", bool, b)
void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
int version, const char* explanation) {
reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
OpDeprecation* dep = builder->op_def.mutable_deprecation();
dep->set_version(version);
dep->set_explanation(explanation);
}
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
TF_Status* status) {
auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
TF_SetStatus(status, TF_OK, "");
::tensorflow::OpRegistry::Global()->Register(
[cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
return cc_builder->Finalize(op_reg_data);
[builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
op_reg_data->op_def.Clear();
op_reg_data->op_def.MergeFrom(builder->op_def);
op_reg_data->shape_inference_fn = builder->shape_inference_func;
return Status::OK();
});
// Calling ProcessRegistrations ensures that the cc_builder's finalize method
@ -86,15 +143,14 @@ void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
Set_TF_Status_from_Status(
status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
delete cc_builder;
delete builder;
}
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
TF_OpDefinitionBuilder* builder,
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
TF_Status* status)) {
auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
cc_builder->SetShapeFn(
builder->shape_inference_func =
[shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
TF_Status* c_status = TF_NewStatus();
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
@ -102,7 +158,7 @@ void TF_OpDefinitionBuilderSetShapeInferenceFunction(
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
TF_DeleteStatus(c_status);
return result;
});
};
}
TF_ShapeHandle* TF_NewShapeHandle() {

View File

@ -125,65 +125,97 @@ TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder(
//----------------------------------------------------
// Attribute functions.
// Adds an attr to the given TF_OpDefinitionBuilder. The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
// where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*.
// By convention, names containing only capital letters are reserved for
// attributes whose values can be inferred by the operator implementation if not
// supplied by the user. If the attribute name contains characters other than
// capital letters, the operator expects the user to provide the attribute value
// at operation runtime.
//
// <type> can be:
// "string", "int", "float", "bool", "type", "shape", or "tensor"
// "numbertype", "realnumbertype", "quantizedtype"
// (meaning "type" with a restriction on valid values)
// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
// (meaning "type" with a restriction containing unions of value types)
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
// (meaning "string" with a restriction on valid values)
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
// (meaning lists of the above types)
// "int >= 2" (meaning "int" with a restriction on valid values)
// "list(string) >= 2", "list(int) >= 2"
// (meaning "list(string)" / "list(int)" with length at least 2)
// <default>, if included, should use the Proto text format
// of <type>. For lists use [a, b, c] format.
//
// Note that any attr specifying the length of an input or output will
// get a default minimum of 1 unless the >= # syntax is used.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddAttr(
TF_OpDefinitionBuilder* builder, const char* attr_spec);
// Adds a string attribute with the given name to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an input to this TF_OpDefinitionBuilder.
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
// * For a single tensor: <type>
// * For a sequence of tensors with the same type: <number>*<type>
// * For a sequence of tensors with different types: <type-list>
// Where:
// <type> is either one of "float", "int32", "string", ...
// or the name of an attr (see TF_OpDefinitionBuilderAddAttr)
// with type "type".
// <number> is the name of an attr with type "int".
// <type-list> is the name of an attr with type "list(type)".
// Adds a string attribute with the given name and default value to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, const char* value);
// Adds a string list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a string list attribute with the given default values to the builder.
// `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, const char* values[],
size_t n);
// Adds an integer attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an integer attribute with the given name and default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, int64_t value);
// Adds an integer list attribute with the given name and no default value to
// the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds an integer list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, int64_t values[],
size_t n);
// Adds a float attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a float attribute with the given name and default value to the builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, float value);
// Adds a float list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a float list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, float values[],
size_t n);
// Adds a boolean attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a boolean attribute with the given name and default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue(
TF_OpDefinitionBuilder* builder, const char* name, bool value);
// Adds a boolean list attribute with the given name and no default value to the
// builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr(
TF_OpDefinitionBuilder* builder, const char* name);
// Adds a boolean list attribute with the given name and default values to the
// builder. `values` must contain at least `n` elements.
TF_CAPI_EXPORT extern void
TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues(
TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n);
// Adds the input with the given name and type to the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput(
TF_OpDefinitionBuilder* builder, const char* input_spec);
TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type);
// Adds an output to this TF_OpDefinitionBuilder.
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
// * For a single tensor: <type>
// * For a sequence of tensors with the same type: <number>*<type>
// * For a sequence of tensors with different types: <type-list>
// Where:
// <type> is either one of "float", "int32", "string", ...
// or the name of an attr (see TF_OpDefinitionBuilderAddAttr)
// with type "type".
// <number> is the name of an attr with type "int".
// <type-list> is the name of an attr with type "list(type)".
// Adds the output with the given name and type to the op.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput(
TF_OpDefinitionBuilder* builder, const char* output_spec);
TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type);
// Sets the commutative property for the op built by the given builder.
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative(
@ -331,6 +363,10 @@ TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
TF_DimensionHandle* result);
// Returns 1 if the given handle represents a known dimension.
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown(
TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle);
// Returns in <*result> a sub-shape of <shape_handle>, with dimensions
// [start:end]. <start> and <end> can be negative, to index from the end of the
// shape. <start> and <end> are set to the rank of <shape_handle> if > rank of

View File

@ -15,14 +15,10 @@ limitations under the License.
#include "tensorflow/c/ops.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -32,10 +28,10 @@ namespace {
TEST(OpsTest, TestBasicOpRegistration) {
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
TF_OpDefinitionBuilderAddAttr(builder, "attr1: string");
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
TF_OpDefinitionBuilderAddInput(builder, "input2: uint16");
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint32");
TF_OpDefinitionBuilderAddStringAttr(builder, "attr1");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -70,8 +66,8 @@ TEST(OpsTest, TestShapeInference_IdentityFunction) {
ShapeInferenceTestOp op("SomeTestOp");
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
@ -82,25 +78,6 @@ TEST(OpsTest, TestShapeInference_IdentityFunction) {
TF_DeleteStatus(status);
}
TEST(OpsTest, TestShapeInference_UnknownShape) {
ShapeInferenceTestOp op("UnknownShapeOp");
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("UnknownShapeOp");
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
TF_OpDefinitionBuilderAddInput(builder, "input2: uint32");
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
TF_OpDefinitionBuilderAddOutput(builder, "output2: uint8");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
builder, &TF_ShapeInferenceContextSetUnknownShape);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
op, "[1,2];[3,4]", "?;?"));
TF_DeleteStatus(status);
}
// Creates an output whose shape is a vector of length
// TF_ShapeInferenceContextRank.
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
@ -119,8 +96,8 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("VectorizeTestOp");
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
TF_Status* status = TF_NewStatus();
TF_RegisterOpDefinition(builder, status);
@ -134,8 +111,11 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
TEST(OpsTest, AttributeAccessors) {
TF_OpDefinitionBuilder* builder =
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
float values[] = {1, 2, 3, 4};
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
builder, "foo1", values, sizeof(values));
TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2",
"my string");
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
TF_OpDefinitionBuilderSetIsAggregate(builder, true);
TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
@ -158,8 +138,13 @@ TEST(OpsTest, AttributeAccessors) {
ASSERT_EQ(4, op.deprecation().version());
ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
ASSERT_EQ(2, op.attr_size());
ASSERT_EQ("int", op.attr(0).type());
ASSERT_EQ(2, op.attr(0).minimum());
ASSERT_EQ("list(float)", op.attr(0).type());
AttrValue::ListValue l = op.attr(0).default_value().list();
ASSERT_EQ(1, l.f(0));
ASSERT_EQ(2, l.f(1));
ASSERT_EQ(3, l.f(2));
ASSERT_EQ(4, l.f(3));
ASSERT_EQ("string", op.attr(1).type());
ASSERT_EQ("my string", op.attr(1).default_value().s());
found = true;
@ -170,151 +155,5 @@ TEST(OpsTest, AttributeAccessors) {
TF_DeleteBuffer(op_list_buffer);
}
#define C_CTX(x) reinterpret_cast<TF_ShapeInferenceContext*>(x)
#define C_SHP(x) reinterpret_cast<TF_ShapeHandle*>(x)
static OpDef MakeOpDef(int num_inputs, int num_outputs) {
OpRegistrationData op_reg_data;
OpDefBuilder b("dummy");
for (int i = 0; i < num_inputs; ++i) {
b.Input(strings::StrCat("i", i, ": float"));
}
for (int i = 0; i < num_outputs; ++i) {
b.Output(strings::StrCat("o", i, ": float"));
}
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
return op_reg_data.op_def;
}
// Tests for shape inference
PartialTensorShape S(std::initializer_list<int64> dims) {
return PartialTensorShape(dims);
}
PartialTensorShape Unknown() { return PartialTensorShape(); }
TEST(OpsTest, ShapeInferenceWithRank) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
{S({10, 20, 30})}, {}, {}, {});
shape_inference::ShapeHandle in0 = c.input(0);
shape_inference::ShapeHandle s1;
TF_Status* status = TF_NewStatus();
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
status);
EXPECT_EQ("[10,20,30]", c.DebugString(s1));
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
status);
EXPECT_EQ("[10,20,30]", c.DebugString(s1));
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 6, C_SHP(&s1),
status);
ASSERT_NE(TF_OK, TF_GetCode(status));
TF_SetStatus(status, TF_OK, "");
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
status);
ASSERT_NE(TF_OK, TF_GetCode(status));
TF_SetStatus(status, TF_OK, "");
TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 4, C_SHP(&s1),
status);
ASSERT_NE(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2),
{Unknown(), S({1, -1, 3})}, {}, {}, {});
shape_inference::ShapeHandle in0 = c.input(0);
shape_inference::ShapeHandle s1;
// WithRankAtMost and WithRankAtLeast on a shape with unknown dimensionality
// always succeed.
TF_Status* status = TF_NewStatus();
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
status);
EXPECT_EQ("?", c.DebugString(s1));
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
status);
EXPECT_EQ("?", c.DebugString(s1));
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteStatus(status);
}
TEST(OpsTest, ShapeInferenceConcatenateShapes) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
{S({1, 2}), S({3, 4})}, {}, {}, {});
ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
shape_inference::ShapeHandle a = c.input(0);
shape_inference::ShapeHandle b = c.input(1);
TF_ShapeHandle* result = TF_NewShapeHandle();
TF_Status* status = TF_NewStatus();
TF_ShapeInferenceContextConcatenateShapes(C_CTX(&c), C_SHP(&a), C_SHP(&b),
result, status);
EXPECT_EQ(
"[1,2,3,4]",
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(result)));
EXPECT_EQ(TF_OK, TF_GetCode(status));
TF_DeleteShapeHandle(result);
TF_DeleteStatus(status);
}
TEST(OpsTest, DimensionHandleValueKnown) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
{S({1, 2}), S({3, 4})}, {}, {}, {});
TF_ShapeHandle* handle =
TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
ASSERT_EQ(
"[43]",
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
ASSERT_EQ(1, TF_ShapeInferenceContextRankKnown(C_CTX(&c), handle));
ASSERT_EQ(1, TF_ShapeInferenceContextRank(C_CTX(&c), handle));
TF_DimensionHandle* dim_handle = TF_NewDimensionHandle();
TF_ShapeInferenceContextDim(C_CTX(&c), handle, 0, dim_handle);
ASSERT_EQ(1, TF_DimensionHandleValueKnown(dim_handle));
ASSERT_EQ(43, TF_DimensionHandleValue(dim_handle));
TF_DeleteShapeHandle(handle);
TF_DeleteDimensionHandle(dim_handle);
}
TEST(OpsTest, ShapeInferenceSubshape) {
NodeDef def;
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
{S({10, 20, 30, 40, 50})}, {}, {}, {});
ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));
TF_ShapeHandle* handle = TF_NewShapeHandle();
TF_Status* status = TF_NewStatus();
TF_ShapeInferenceContextGetInput(C_CTX(&c), 0, handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
TF_ShapeInferenceContextSubshape(C_CTX(&c), handle, 1, -1, handle, status);
ASSERT_EQ(TF_OK, TF_GetCode(status));
ASSERT_EQ(
"[20,30,40]",
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
TF_DeleteStatus(status);
TF_DeleteShapeHandle(handle);
}
} // namespace
} // namespace tensorflow