diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 1c0c8f33b2a..99eb28c1295 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -221,6 +221,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = [ ":c_api_internal", + ":c_api_no_xla", "//tensorflow/core:lib", ], ) @@ -527,13 +528,10 @@ tf_cc_test( deps = [ ":c_api", ":ops", - "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/c/ops.cc b/tensorflow/c/ops.cc index d806a16cbc0..b175d262c01 100644 --- a/tensorflow/c/ops.cc +++ b/tensorflow/c/ops.cc @@ -23,7 +23,6 @@ limitations under the License. using ::tensorflow::DataType; using ::tensorflow::OpDef; -using ::tensorflow::OpDefBuilder; using ::tensorflow::OpDeprecation; using ::tensorflow::OpShapeInferenceFn; using ::tensorflow::Set_TF_Status_from_Status; @@ -32,53 +31,111 @@ using ::tensorflow::shape_inference::DimensionHandle; using ::tensorflow::shape_inference::InferenceContext; using ::tensorflow::shape_inference::ShapeHandle; +typedef struct TF_OpDefinitionBuilder { + // The op definition proto representing the op. + tensorflow::OpDef op_def; + + // The shape inference function, or nullptr if none is provided for this op. + OpShapeInferenceFn shape_inference_func; +} TF_OpDefinitionBuilder; + TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) { - auto* result = new OpDefBuilder(op_name); - return reinterpret_cast(result); + auto* result = new TF_OpDefinitionBuilder; + result->op_def.set_name(op_name); + return result; } void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) { - delete reinterpret_cast(builder); + delete builder; +} + +static void PopulateArg(OpDef::ArgDef* arg, const char* name, + TF_DataType type) { + arg->set_name(name); + arg->set_type(static_cast(type)); } void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder, - const char* input_spec) { - reinterpret_cast(builder)->Input(input_spec); + const char* name, TF_DataType type) { + PopulateArg(builder->op_def.add_input_arg(), name, type); } void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder, - const char* output_spec) { - reinterpret_cast(builder)->Output(output_spec); + const char* name, TF_DataType type) { + PopulateArg(builder->op_def.add_output_arg(), name, type); } -#define DEFINE_BUILDER_BOOL_SETTER(func_name) \ - void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \ - bool arg_name) { \ - reinterpret_cast(builder)->func_name(); \ +#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \ + void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \ + bool arg_name) { \ + builder->op_def.builder_setter_name(arg_name); \ } -DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative) -DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate) -DEFINE_BUILDER_BOOL_SETTER(SetIsStateful) -DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput) +DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative) +DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate) +DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful) +DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput, + set_allows_uninitialized_input, + allows_unintialized_input) -void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder, - const char* attr_spec) { - reinterpret_cast(builder)->Attr(attr_spec); +static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder, + const char* name, const char* type_name) { + OpDef::AttrDef* attr = builder->op_def.add_attr(); + attr->set_name(name); + attr->set_type(type_name); + return attr; } +#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \ + void TF_OpDefinitionBuilderAdd##attr_type##Attr( \ + TF_OpDefinitionBuilder* builder, const char* name) { \ + AddAttribute(builder, name, type_name); \ + } \ + \ + void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \ + TF_OpDefinitionBuilder* builder, const char* name, \ + field_c_type field_name) { \ + OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \ + attr->mutable_default_value()->set_##field_name(field_name); \ + } \ + \ + void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \ + TF_OpDefinitionBuilder* builder, const char* name, \ + field_c_type field_name[], size_t n) { \ + OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \ + for (int _i = 0; _i < n; ++_i) { \ + attr->mutable_default_value()->mutable_list()->add_##field_name( \ + field_name[_i]); \ + } \ + } \ + \ + void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \ + TF_OpDefinitionBuilder* builder, const char* name) { \ + TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \ + builder, name, NULL, 0); \ + } + +DEFINE_ATTR_SETTER(String, "string", const char*, s) +DEFINE_ATTR_SETTER(Int, "int", int64_t, i) +DEFINE_ATTR_SETTER(Float, "float", float, f) +DEFINE_ATTR_SETTER(Bool, "bool", bool, b) + void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder, int version, const char* explanation) { - reinterpret_cast(builder)->Deprecated(version, explanation); + OpDeprecation* dep = builder->op_def.mutable_deprecation(); + dep->set_version(version); + dep->set_explanation(explanation); } void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder, TF_Status* status) { - auto* cc_builder = reinterpret_cast(builder); TF_SetStatus(status, TF_OK, ""); ::tensorflow::OpRegistry::Global()->Register( - [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status { - return cc_builder->Finalize(op_reg_data); + [builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status { + op_reg_data->op_def.Clear(); + op_reg_data->op_def.MergeFrom(builder->op_def); + op_reg_data->shape_inference_fn = builder->shape_inference_func; + return Status::OK(); }); // Calling ProcessRegistrations ensures that the cc_builder's finalize method @@ -86,23 +143,22 @@ void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder, Set_TF_Status_from_Status( status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations()); - delete cc_builder; + delete builder; } void TF_OpDefinitionBuilderSetShapeInferenceFunction( TF_OpDefinitionBuilder* builder, void (*shape_inference_func)(TF_ShapeInferenceContext* ctx, TF_Status* status)) { - auto* cc_builder = reinterpret_cast(builder); - cc_builder->SetShapeFn( + builder->shape_inference_func = [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status { - TF_Status* c_status = TF_NewStatus(); - auto c_ctx = reinterpret_cast(ctx); - shape_inference_func(c_ctx, c_status); - tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status); - TF_DeleteStatus(c_status); - return result; - }); + TF_Status* c_status = TF_NewStatus(); + auto c_ctx = reinterpret_cast(ctx); + shape_inference_func(c_ctx, c_status); + tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status); + TF_DeleteStatus(c_status); + return result; + }; } TF_ShapeHandle* TF_NewShapeHandle() { diff --git a/tensorflow/c/ops.h b/tensorflow/c/ops.h index 6f941a06fbc..7e2e95084ea 100644 --- a/tensorflow/c/ops.h +++ b/tensorflow/c/ops.h @@ -125,65 +125,97 @@ TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder( //---------------------------------------------------- // Attribute functions. -// Adds an attr to the given TF_OpDefinitionBuilder. The spec has -// format ":" or ":=" -// where matches regexp [a-zA-Z][a-zA-Z0-9_]*. -// By convention, names containing only capital letters are reserved for -// attributes whose values can be inferred by the operator implementation if not -// supplied by the user. If the attribute name contains characters other than -// capital letters, the operator expects the user to provide the attribute value -// at operation runtime. -// -// can be: -// "string", "int", "float", "bool", "type", "shape", or "tensor" -// "numbertype", "realnumbertype", "quantizedtype" -// (meaning "type" with a restriction on valid values) -// "{int32,int64}" or {realnumbertype,quantizedtype,string}" -// (meaning "type" with a restriction containing unions of value types) -// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}" -// (meaning "string" with a restriction on valid values) -// "list(string)", ..., "list(tensor)", "list(numbertype)", ... -// (meaning lists of the above types) -// "int >= 2" (meaning "int" with a restriction on valid values) -// "list(string) >= 2", "list(int) >= 2" -// (meaning "list(string)" / "list(int)" with length at least 2) -// , if included, should use the Proto text format -// of . For lists use [a, b, c] format. -// -// Note that any attr specifying the length of an input or output will -// get a default minimum of 1 unless the >= # syntax is used. -TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddAttr( - TF_OpDefinitionBuilder* builder, const char* attr_spec); +// Adds a string attribute with the given name to the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr( + TF_OpDefinitionBuilder* builder, const char* name); -// Adds an input to this TF_OpDefinitionBuilder. -// The spec has form ":" or ":Ref()" -// where matches regexp [a-z][a-z0-9_]* and can be: -// * For a single tensor: -// * For a sequence of tensors with the same type: * -// * For a sequence of tensors with different types: -// Where: -// is either one of "float", "int32", "string", ... -// or the name of an attr (see TF_OpDefinitionBuilderAddAttr) -// with type "type". -// is the name of an attr with type "int". -// is the name of an attr with type "list(type)". +// Adds a string attribute with the given name and default value to the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, const char* value); + +// Adds a string list attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a string list attribute with the given default values to the builder. +// `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, const char* values[], + size_t n); + +// Adds an integer attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds an integer attribute with the given name and default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, int64_t value); + +// Adds an integer list attribute with the given name and no default value to +// the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds an integer list attribute with the given name and default values to the +// builder. `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, int64_t values[], + size_t n); + +// Adds a float attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a float attribute with the given name and default value to the builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, float value); + +// Adds a float list attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a float list attribute with the given name and default values to the +// builder. `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, float values[], + size_t n); + +// Adds a boolean attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a boolean attribute with the given name and default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue( + TF_OpDefinitionBuilder* builder, const char* name, bool value); + +// Adds a boolean list attribute with the given name and no default value to the +// builder. +TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr( + TF_OpDefinitionBuilder* builder, const char* name); + +// Adds a boolean list attribute with the given name and default values to the +// builder. `values` must contain at least `n` elements. +TF_CAPI_EXPORT extern void +TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues( + TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n); + +// Adds the input with the given name and type to the op. TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput( - TF_OpDefinitionBuilder* builder, const char* input_spec); + TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type); -// Adds an output to this TF_OpDefinitionBuilder. -// The spec has form ":" or ":Ref()" -// where matches regexp [a-z][a-z0-9_]* and can be: -// * For a single tensor: -// * For a sequence of tensors with the same type: * -// * For a sequence of tensors with different types: -// Where: -// is either one of "float", "int32", "string", ... -// or the name of an attr (see TF_OpDefinitionBuilderAddAttr) -// with type "type". -// is the name of an attr with type "int". -// is the name of an attr with type "list(type)". +// Adds the output with the given name and type to the op. TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput( - TF_OpDefinitionBuilder* builder, const char* output_spec); + TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type); // Sets the commutative property for the op built by the given builder. TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative( @@ -331,6 +363,10 @@ TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim( TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i, TF_DimensionHandle* result); +// Returns 1 if the given handle represents a known dimension. +TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown( + TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle); + // Returns in <*result> a sub-shape of , with dimensions // [start:end]. and can be negative, to index from the end of the // shape. and are set to the rank of if > rank of diff --git a/tensorflow/c/ops_test.cc b/tensorflow/c/ops_test.cc index 0a6c5cd50fb..2b40f96157e 100644 --- a/tensorflow/c/ops_test.cc +++ b/tensorflow/c/ops_test.cc @@ -15,14 +15,10 @@ limitations under the License. #include "tensorflow/c/ops.h" -#include "absl/strings/str_cat.h" #include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" -#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -32,10 +28,10 @@ namespace { TEST(OpsTest, TestBasicOpRegistration) { TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp"); - TF_OpDefinitionBuilderAddAttr(builder, "attr1: string"); - TF_OpDefinitionBuilderAddInput(builder, "input1: uint8"); - TF_OpDefinitionBuilderAddInput(builder, "input2: uint16"); - TF_OpDefinitionBuilderAddOutput(builder, "output1: uint32"); + TF_OpDefinitionBuilderAddStringAttr(builder, "attr1"); + TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8); + TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16); + TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32); TF_Status* status = TF_NewStatus(); TF_RegisterOpDefinition(builder, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -70,8 +66,8 @@ TEST(OpsTest, TestShapeInference_IdentityFunction) { ShapeInferenceTestOp op("SomeTestOp"); TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp"); - TF_OpDefinitionBuilderAddInput(builder, "input1: uint8"); - TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8"); + TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8); + TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8); TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn); TF_Status* status = TF_NewStatus(); TF_RegisterOpDefinition(builder, status); @@ -82,25 +78,6 @@ TEST(OpsTest, TestShapeInference_IdentityFunction) { TF_DeleteStatus(status); } -TEST(OpsTest, TestShapeInference_UnknownShape) { - ShapeInferenceTestOp op("UnknownShapeOp"); - - TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("UnknownShapeOp"); - TF_OpDefinitionBuilderAddInput(builder, "input1: uint8"); - TF_OpDefinitionBuilderAddInput(builder, "input2: uint32"); - TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8"); - TF_OpDefinitionBuilderAddOutput(builder, "output2: uint8"); - TF_OpDefinitionBuilderSetShapeInferenceFunction( - builder, &TF_ShapeInferenceContextSetUnknownShape); - TF_Status* status = TF_NewStatus(); - TF_RegisterOpDefinition(builder, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - - TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes( - op, "[1,2];[3,4]", "?;?")); - TF_DeleteStatus(status); -} - // Creates an output whose shape is a vector of length // TF_ShapeInferenceContextRank. void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) { @@ -119,8 +96,8 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) { TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("VectorizeTestOp"); - TF_OpDefinitionBuilderAddInput(builder, "input1: uint8"); - TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8"); + TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8); + TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8); TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn); TF_Status* status = TF_NewStatus(); TF_RegisterOpDefinition(builder, status); @@ -134,8 +111,11 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) { TEST(OpsTest, AttributeAccessors) { TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("AttributeAccesorsOp"); - TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2"); - TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\""); + float values[] = {1, 2, 3, 4}; + TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues( + builder, "foo1", values, sizeof(values)); + TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2", + "my string"); TF_OpDefinitionBuilderSetIsCommutative(builder, true); TF_OpDefinitionBuilderSetIsAggregate(builder, true); TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true); @@ -158,8 +138,13 @@ TEST(OpsTest, AttributeAccessors) { ASSERT_EQ(4, op.deprecation().version()); ASSERT_EQ(deprecation_msg, op.deprecation().explanation()); ASSERT_EQ(2, op.attr_size()); - ASSERT_EQ("int", op.attr(0).type()); - ASSERT_EQ(2, op.attr(0).minimum()); + ASSERT_EQ("list(float)", op.attr(0).type()); + AttrValue::ListValue l = op.attr(0).default_value().list(); + ASSERT_EQ(1, l.f(0)); + ASSERT_EQ(2, l.f(1)); + ASSERT_EQ(3, l.f(2)); + ASSERT_EQ(4, l.f(3)); + ASSERT_EQ("string", op.attr(1).type()); ASSERT_EQ("my string", op.attr(1).default_value().s()); found = true; @@ -170,151 +155,5 @@ TEST(OpsTest, AttributeAccessors) { TF_DeleteBuffer(op_list_buffer); } -#define C_CTX(x) reinterpret_cast(x) -#define C_SHP(x) reinterpret_cast(x) - -static OpDef MakeOpDef(int num_inputs, int num_outputs) { - OpRegistrationData op_reg_data; - OpDefBuilder b("dummy"); - for (int i = 0; i < num_inputs; ++i) { - b.Input(strings::StrCat("i", i, ": float")); - } - for (int i = 0; i < num_outputs; ++i) { - b.Output(strings::StrCat("o", i, ": float")); - } - CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok()); - return op_reg_data.op_def; -} - -// Tests for shape inference - -PartialTensorShape S(std::initializer_list dims) { - return PartialTensorShape(dims); -} - -PartialTensorShape Unknown() { return PartialTensorShape(); } - -TEST(OpsTest, ShapeInferenceWithRank) { - NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0), - {S({10, 20, 30})}, {}, {}, {}); - - shape_inference::ShapeHandle in0 = c.input(0); - shape_inference::ShapeHandle s1; - - TF_Status* status = TF_NewStatus(); - TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1), - status); - EXPECT_EQ("[10,20,30]", c.DebugString(s1)); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - - TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1), - status); - EXPECT_EQ("[10,20,30]", c.DebugString(s1)); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - - TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 6, C_SHP(&s1), - status); - ASSERT_NE(TF_OK, TF_GetCode(status)); - - TF_SetStatus(status, TF_OK, ""); - TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1), - status); - ASSERT_NE(TF_OK, TF_GetCode(status)); - - TF_SetStatus(status, TF_OK, ""); - TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1), - status); - ASSERT_EQ(TF_OK, TF_GetCode(status)); - - TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 4, C_SHP(&s1), - status); - ASSERT_NE(TF_OK, TF_GetCode(status)); - - TF_DeleteStatus(status); -} - -TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) { - NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2), - {Unknown(), S({1, -1, 3})}, {}, {}, {}); - - shape_inference::ShapeHandle in0 = c.input(0); - shape_inference::ShapeHandle s1; - - // WithRankAtMost and WithRankAtLeast on a shape with unknown dimensionality - // always succeed. - TF_Status* status = TF_NewStatus(); - TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1), - status); - EXPECT_EQ("?", c.DebugString(s1)); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - - TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1), - status); - EXPECT_EQ("?", c.DebugString(s1)); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - - TF_DeleteStatus(status); -} - -TEST(OpsTest, ShapeInferenceConcatenateShapes) { - NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0), - {S({1, 2}), S({3, 4})}, {}, {}, {}); - ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c))); - shape_inference::ShapeHandle a = c.input(0); - shape_inference::ShapeHandle b = c.input(1); - TF_ShapeHandle* result = TF_NewShapeHandle(); - TF_Status* status = TF_NewStatus(); - TF_ShapeInferenceContextConcatenateShapes(C_CTX(&c), C_SHP(&a), C_SHP(&b), - result, status); - EXPECT_EQ( - "[1,2,3,4]", - c.DebugString(*reinterpret_cast(result))); - EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_DeleteShapeHandle(result); - TF_DeleteStatus(status); -} - -TEST(OpsTest, DimensionHandleValueKnown) { - NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0), - {S({1, 2}), S({3, 4})}, {}, {}, {}); - TF_ShapeHandle* handle = - TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43); - ASSERT_EQ( - "[43]", - c.DebugString(*reinterpret_cast(handle))); - ASSERT_EQ(1, TF_ShapeInferenceContextRankKnown(C_CTX(&c), handle)); - ASSERT_EQ(1, TF_ShapeInferenceContextRank(C_CTX(&c), handle)); - - TF_DimensionHandle* dim_handle = TF_NewDimensionHandle(); - TF_ShapeInferenceContextDim(C_CTX(&c), handle, 0, dim_handle); - ASSERT_EQ(1, TF_DimensionHandleValueKnown(dim_handle)); - ASSERT_EQ(43, TF_DimensionHandleValue(dim_handle)); - TF_DeleteShapeHandle(handle); - TF_DeleteDimensionHandle(dim_handle); -} - -TEST(OpsTest, ShapeInferenceSubshape) { - NodeDef def; - shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0), - {S({10, 20, 30, 40, 50})}, {}, {}, {}); - ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0))); - - TF_ShapeHandle* handle = TF_NewShapeHandle(); - TF_Status* status = TF_NewStatus(); - TF_ShapeInferenceContextGetInput(C_CTX(&c), 0, handle, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)); - TF_ShapeInferenceContextSubshape(C_CTX(&c), handle, 1, -1, handle, status); - ASSERT_EQ(TF_OK, TF_GetCode(status)); - ASSERT_EQ( - "[20,30,40]", - c.DebugString(*reinterpret_cast(handle))); - TF_DeleteStatus(status); - TF_DeleteShapeHandle(handle); -} - } // namespace } // namespace tensorflow