Automated rollback of commit fa636543de3159b35e304f7db670d13df8af9878
PiperOrigin-RevId: 250567146
This commit is contained in:
parent
0facf5a30f
commit
b25e77ca06
@ -221,6 +221,7 @@ tf_cuda_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api_internal",
|
||||
":c_api_no_xla",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
@ -527,13 +528,10 @@ tf_cc_test(
|
||||
deps = [
|
||||
":c_api",
|
||||
":ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
|
||||
using ::tensorflow::DataType;
|
||||
using ::tensorflow::OpDef;
|
||||
using ::tensorflow::OpDefBuilder;
|
||||
using ::tensorflow::OpDeprecation;
|
||||
using ::tensorflow::OpShapeInferenceFn;
|
||||
using ::tensorflow::Set_TF_Status_from_Status;
|
||||
@ -32,53 +31,111 @@ using ::tensorflow::shape_inference::DimensionHandle;
|
||||
using ::tensorflow::shape_inference::InferenceContext;
|
||||
using ::tensorflow::shape_inference::ShapeHandle;
|
||||
|
||||
typedef struct TF_OpDefinitionBuilder {
|
||||
// The op definition proto representing the op.
|
||||
tensorflow::OpDef op_def;
|
||||
|
||||
// The shape inference function, or nullptr if none is provided for this op.
|
||||
OpShapeInferenceFn shape_inference_func;
|
||||
} TF_OpDefinitionBuilder;
|
||||
|
||||
TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
|
||||
auto* result = new OpDefBuilder(op_name);
|
||||
return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
|
||||
auto* result = new TF_OpDefinitionBuilder;
|
||||
result->op_def.set_name(op_name);
|
||||
return result;
|
||||
}
|
||||
|
||||
void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
|
||||
delete reinterpret_cast<OpDefBuilder*>(builder);
|
||||
delete builder;
|
||||
}
|
||||
|
||||
static void PopulateArg(OpDef::ArgDef* arg, const char* name,
|
||||
TF_DataType type) {
|
||||
arg->set_name(name);
|
||||
arg->set_type(static_cast<DataType>(type));
|
||||
}
|
||||
|
||||
void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
|
||||
const char* input_spec) {
|
||||
reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
|
||||
const char* name, TF_DataType type) {
|
||||
PopulateArg(builder->op_def.add_input_arg(), name, type);
|
||||
}
|
||||
|
||||
void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
|
||||
const char* output_spec) {
|
||||
reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
|
||||
const char* name, TF_DataType type) {
|
||||
PopulateArg(builder->op_def.add_output_arg(), name, type);
|
||||
}
|
||||
|
||||
#define DEFINE_BUILDER_BOOL_SETTER(func_name) \
|
||||
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
|
||||
bool arg_name) { \
|
||||
reinterpret_cast<OpDefBuilder*>(builder)->func_name(); \
|
||||
#define DEFINE_BUILDER_BOOL_SETTER(func_name, builder_setter_name, arg_name) \
|
||||
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
|
||||
bool arg_name) { \
|
||||
builder->op_def.builder_setter_name(arg_name); \
|
||||
}
|
||||
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative, set_is_commutative, is_commutative)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate, set_is_aggregate, is_aggregate)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful, set_is_stateful, is_stateful)
|
||||
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput,
|
||||
set_allows_uninitialized_input,
|
||||
allows_unintialized_input)
|
||||
|
||||
void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
|
||||
const char* attr_spec) {
|
||||
reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
|
||||
static OpDef::AttrDef* AddAttribute(TF_OpDefinitionBuilder* builder,
|
||||
const char* name, const char* type_name) {
|
||||
OpDef::AttrDef* attr = builder->op_def.add_attr();
|
||||
attr->set_name(name);
|
||||
attr->set_type(type_name);
|
||||
return attr;
|
||||
}
|
||||
|
||||
#define DEFINE_ATTR_SETTER(attr_type, type_name, field_c_type, field_name) \
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##Attr( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name) { \
|
||||
AddAttribute(builder, name, type_name); \
|
||||
} \
|
||||
\
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##AttrWithDefaultValue( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name, \
|
||||
field_c_type field_name) { \
|
||||
OpDef::AttrDef* attr = AddAttribute(builder, name, type_name); \
|
||||
attr->mutable_default_value()->set_##field_name(field_name); \
|
||||
} \
|
||||
\
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name, \
|
||||
field_c_type field_name[], size_t n) { \
|
||||
OpDef::AttrDef* attr = AddAttribute(builder, name, "list(" type_name ")"); \
|
||||
for (int _i = 0; _i < n; ++_i) { \
|
||||
attr->mutable_default_value()->mutable_list()->add_##field_name( \
|
||||
field_name[_i]); \
|
||||
} \
|
||||
} \
|
||||
\
|
||||
void TF_OpDefinitionBuilderAdd##attr_type##ListAttr( \
|
||||
TF_OpDefinitionBuilder* builder, const char* name) { \
|
||||
TF_OpDefinitionBuilderAdd##attr_type##ListAttrWithDefaultValues( \
|
||||
builder, name, NULL, 0); \
|
||||
}
|
||||
|
||||
DEFINE_ATTR_SETTER(String, "string", const char*, s)
|
||||
DEFINE_ATTR_SETTER(Int, "int", int64_t, i)
|
||||
DEFINE_ATTR_SETTER(Float, "float", float, f)
|
||||
DEFINE_ATTR_SETTER(Bool, "bool", bool, b)
|
||||
|
||||
void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
|
||||
int version, const char* explanation) {
|
||||
reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
|
||||
OpDeprecation* dep = builder->op_def.mutable_deprecation();
|
||||
dep->set_version(version);
|
||||
dep->set_explanation(explanation);
|
||||
}
|
||||
|
||||
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
|
||||
TF_Status* status) {
|
||||
auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
::tensorflow::OpRegistry::Global()->Register(
|
||||
[cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
return cc_builder->Finalize(op_reg_data);
|
||||
[builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
||||
op_reg_data->op_def.Clear();
|
||||
op_reg_data->op_def.MergeFrom(builder->op_def);
|
||||
op_reg_data->shape_inference_fn = builder->shape_inference_func;
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Calling ProcessRegistrations ensures that the cc_builder's finalize method
|
||||
@ -86,23 +143,22 @@ void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
|
||||
Set_TF_Status_from_Status(
|
||||
status, ::tensorflow::OpRegistry::Global()->ProcessRegistrations());
|
||||
|
||||
delete cc_builder;
|
||||
delete builder;
|
||||
}
|
||||
|
||||
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
||||
TF_OpDefinitionBuilder* builder,
|
||||
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
|
||||
TF_Status* status)) {
|
||||
auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
|
||||
cc_builder->SetShapeFn(
|
||||
builder->shape_inference_func =
|
||||
[shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
|
||||
TF_Status* c_status = TF_NewStatus();
|
||||
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
|
||||
shape_inference_func(c_ctx, c_status);
|
||||
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
|
||||
TF_DeleteStatus(c_status);
|
||||
return result;
|
||||
});
|
||||
TF_Status* c_status = TF_NewStatus();
|
||||
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
|
||||
shape_inference_func(c_ctx, c_status);
|
||||
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
|
||||
TF_DeleteStatus(c_status);
|
||||
return result;
|
||||
};
|
||||
}
|
||||
|
||||
TF_ShapeHandle* TF_NewShapeHandle() {
|
||||
|
@ -125,65 +125,97 @@ TF_CAPI_EXPORT extern void TF_DeleteOpDefinitionBuilder(
|
||||
//----------------------------------------------------
|
||||
// Attribute functions.
|
||||
|
||||
// Adds an attr to the given TF_OpDefinitionBuilder. The spec has
|
||||
// format "<name>:<type>" or "<name>:<type>=<default>"
|
||||
// where <name> matches regexp [a-zA-Z][a-zA-Z0-9_]*.
|
||||
// By convention, names containing only capital letters are reserved for
|
||||
// attributes whose values can be inferred by the operator implementation if not
|
||||
// supplied by the user. If the attribute name contains characters other than
|
||||
// capital letters, the operator expects the user to provide the attribute value
|
||||
// at operation runtime.
|
||||
//
|
||||
// <type> can be:
|
||||
// "string", "int", "float", "bool", "type", "shape", or "tensor"
|
||||
// "numbertype", "realnumbertype", "quantizedtype"
|
||||
// (meaning "type" with a restriction on valid values)
|
||||
// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
|
||||
// (meaning "type" with a restriction containing unions of value types)
|
||||
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
|
||||
// (meaning "string" with a restriction on valid values)
|
||||
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
|
||||
// (meaning lists of the above types)
|
||||
// "int >= 2" (meaning "int" with a restriction on valid values)
|
||||
// "list(string) >= 2", "list(int) >= 2"
|
||||
// (meaning "list(string)" / "list(int)" with length at least 2)
|
||||
// <default>, if included, should use the Proto text format
|
||||
// of <type>. For lists use [a, b, c] format.
|
||||
//
|
||||
// Note that any attr specifying the length of an input or output will
|
||||
// get a default minimum of 1 unless the >= # syntax is used.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* attr_spec);
|
||||
// Adds a string attribute with the given name to the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds an input to this TF_OpDefinitionBuilder.
|
||||
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
|
||||
// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
|
||||
// * For a single tensor: <type>
|
||||
// * For a sequence of tensors with the same type: <number>*<type>
|
||||
// * For a sequence of tensors with different types: <type-list>
|
||||
// Where:
|
||||
// <type> is either one of "float", "int32", "string", ...
|
||||
// or the name of an attr (see TF_OpDefinitionBuilderAddAttr)
|
||||
// with type "type".
|
||||
// <number> is the name of an attr with type "int".
|
||||
// <type-list> is the name of an attr with type "list(type)".
|
||||
// Adds a string attribute with the given name and default value to the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, const char* value);
|
||||
|
||||
// Adds a string list attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddStringListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a string list attribute with the given default values to the builder.
|
||||
// `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddStringListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, const char* values[],
|
||||
size_t n);
|
||||
|
||||
// Adds an integer attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds an integer attribute with the given name and default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, int64_t value);
|
||||
|
||||
// Adds an integer list attribute with the given name and no default value to
|
||||
// the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddIntListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds an integer list attribute with the given name and default values to the
|
||||
// builder. `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddIntListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, int64_t values[],
|
||||
size_t n);
|
||||
|
||||
// Adds a float attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a float attribute with the given name and default value to the builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, float value);
|
||||
|
||||
// Adds a float list attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddFloatListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a float list attribute with the given name and default values to the
|
||||
// builder. `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, float values[],
|
||||
size_t n);
|
||||
|
||||
// Adds a boolean attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a boolean attribute with the given name and default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolAttrWithDefaultValue(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, bool value);
|
||||
|
||||
// Adds a boolean list attribute with the given name and no default value to the
|
||||
// builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddBoolListAttr(
|
||||
TF_OpDefinitionBuilder* builder, const char* name);
|
||||
|
||||
// Adds a boolean list attribute with the given name and default values to the
|
||||
// builder. `values` must contain at least `n` elements.
|
||||
TF_CAPI_EXPORT extern void
|
||||
TF_OpDefinitionBuilderAddBoolListAttrWithDefaultValues(
|
||||
TF_OpDefinitionBuilder* builder, const char* name, bool values[], size_t n);
|
||||
|
||||
// Adds the input with the given name and type to the op.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddInput(
|
||||
TF_OpDefinitionBuilder* builder, const char* input_spec);
|
||||
TF_OpDefinitionBuilder* builder, const char* name, TF_DataType type);
|
||||
|
||||
// Adds an output to this TF_OpDefinitionBuilder.
|
||||
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
|
||||
// where <name> matches regexp [a-z][a-z0-9_]* and <type-expr> can be:
|
||||
// * For a single tensor: <type>
|
||||
// * For a sequence of tensors with the same type: <number>*<type>
|
||||
// * For a sequence of tensors with different types: <type-list>
|
||||
// Where:
|
||||
// <type> is either one of "float", "int32", "string", ...
|
||||
// or the name of an attr (see TF_OpDefinitionBuilderAddAttr)
|
||||
// with type "type".
|
||||
// <number> is the name of an attr with type "int".
|
||||
// <type-list> is the name of an attr with type "list(type)".
|
||||
// Adds the output with the given name and type to the op.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderAddOutput(
|
||||
TF_OpDefinitionBuilder* builder, const char* output_spec);
|
||||
TF_OpDefinitionBuilder* builder, const char* output, TF_DataType type);
|
||||
|
||||
// Sets the commutative property for the op built by the given builder.
|
||||
TF_CAPI_EXPORT extern void TF_OpDefinitionBuilderSetIsCommutative(
|
||||
@ -331,6 +363,10 @@ TF_CAPI_EXPORT extern void TF_ShapeInferenceContextDim(
|
||||
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* shape_handle, int64_t i,
|
||||
TF_DimensionHandle* result);
|
||||
|
||||
// Returns 1 if the given handle represents a known dimension.
|
||||
TF_CAPI_EXPORT extern int TF_ShapeInferenceContextDimValueKnown(
|
||||
TF_ShapeInferenceContext* ctx, TF_DimensionHandle* handle);
|
||||
|
||||
// Returns in <*result> a sub-shape of <shape_handle>, with dimensions
|
||||
// [start:end]. <start> and <end> can be negative, to index from the end of the
|
||||
// shape. <start> and <end> are set to the rank of <shape_handle> if > rank of
|
||||
|
@ -15,14 +15,10 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/ops.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/shape_inference_testutil.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -32,10 +28,10 @@ namespace {
|
||||
|
||||
TEST(OpsTest, TestBasicOpRegistration) {
|
||||
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "attr1: string");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input2: uint16");
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint32");
|
||||
TF_OpDefinitionBuilderAddStringAttr(builder, "attr1");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input2", TF_UINT16);
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT32);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
@ -70,8 +66,8 @@ TEST(OpsTest, TestShapeInference_IdentityFunction) {
|
||||
ShapeInferenceTestOp op("SomeTestOp");
|
||||
|
||||
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
@ -82,25 +78,6 @@ TEST(OpsTest, TestShapeInference_IdentityFunction) {
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(OpsTest, TestShapeInference_UnknownShape) {
|
||||
ShapeInferenceTestOp op("UnknownShapeOp");
|
||||
|
||||
TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("UnknownShapeOp");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input2: uint32");
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output2: uint8");
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
||||
builder, &TF_ShapeInferenceContextSetUnknownShape);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
|
||||
op, "[1,2];[3,4]", "?;?"));
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
// Creates an output whose shape is a vector of length
|
||||
// TF_ShapeInferenceContextRank.
|
||||
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
|
||||
@ -119,8 +96,8 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
||||
|
||||
TF_OpDefinitionBuilder* builder =
|
||||
TF_NewOpDefinitionBuilder("VectorizeTestOp");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
|
||||
TF_OpDefinitionBuilderAddInput(builder, "input1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderAddOutput(builder, "output1", TF_UINT8);
|
||||
TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_RegisterOpDefinition(builder, status);
|
||||
@ -134,8 +111,11 @@ TEST(OpsTest, TestShapeInference_VectorizeFunction) {
|
||||
TEST(OpsTest, AttributeAccessors) {
|
||||
TF_OpDefinitionBuilder* builder =
|
||||
TF_NewOpDefinitionBuilder("AttributeAccesorsOp");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
|
||||
TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
|
||||
float values[] = {1, 2, 3, 4};
|
||||
TF_OpDefinitionBuilderAddFloatListAttrWithDefaultValues(
|
||||
builder, "foo1", values, sizeof(values));
|
||||
TF_OpDefinitionBuilderAddStringAttrWithDefaultValue(builder, "foo2",
|
||||
"my string");
|
||||
TF_OpDefinitionBuilderSetIsCommutative(builder, true);
|
||||
TF_OpDefinitionBuilderSetIsAggregate(builder, true);
|
||||
TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
|
||||
@ -158,8 +138,13 @@ TEST(OpsTest, AttributeAccessors) {
|
||||
ASSERT_EQ(4, op.deprecation().version());
|
||||
ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
|
||||
ASSERT_EQ(2, op.attr_size());
|
||||
ASSERT_EQ("int", op.attr(0).type());
|
||||
ASSERT_EQ(2, op.attr(0).minimum());
|
||||
ASSERT_EQ("list(float)", op.attr(0).type());
|
||||
AttrValue::ListValue l = op.attr(0).default_value().list();
|
||||
ASSERT_EQ(1, l.f(0));
|
||||
ASSERT_EQ(2, l.f(1));
|
||||
ASSERT_EQ(3, l.f(2));
|
||||
ASSERT_EQ(4, l.f(3));
|
||||
|
||||
ASSERT_EQ("string", op.attr(1).type());
|
||||
ASSERT_EQ("my string", op.attr(1).default_value().s());
|
||||
found = true;
|
||||
@ -170,151 +155,5 @@ TEST(OpsTest, AttributeAccessors) {
|
||||
TF_DeleteBuffer(op_list_buffer);
|
||||
}
|
||||
|
||||
#define C_CTX(x) reinterpret_cast<TF_ShapeInferenceContext*>(x)
|
||||
#define C_SHP(x) reinterpret_cast<TF_ShapeHandle*>(x)
|
||||
|
||||
static OpDef MakeOpDef(int num_inputs, int num_outputs) {
|
||||
OpRegistrationData op_reg_data;
|
||||
OpDefBuilder b("dummy");
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
b.Input(strings::StrCat("i", i, ": float"));
|
||||
}
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
b.Output(strings::StrCat("o", i, ": float"));
|
||||
}
|
||||
CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
|
||||
return op_reg_data.op_def;
|
||||
}
|
||||
|
||||
// Tests for shape inference
|
||||
|
||||
PartialTensorShape S(std::initializer_list<int64> dims) {
|
||||
return PartialTensorShape(dims);
|
||||
}
|
||||
|
||||
PartialTensorShape Unknown() { return PartialTensorShape(); }
|
||||
|
||||
TEST(OpsTest, ShapeInferenceWithRank) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
|
||||
{S({10, 20, 30})}, {}, {}, {});
|
||||
|
||||
shape_inference::ShapeHandle in0 = c.input(0);
|
||||
shape_inference::ShapeHandle s1;
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
|
||||
status);
|
||||
EXPECT_EQ("[10,20,30]", c.DebugString(s1));
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
|
||||
status);
|
||||
EXPECT_EQ("[10,20,30]", c.DebugString(s1));
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 6, C_SHP(&s1),
|
||||
status);
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
|
||||
status);
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
|
||||
status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 4, C_SHP(&s1),
|
||||
status);
|
||||
ASSERT_NE(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 2),
|
||||
{Unknown(), S({1, -1, 3})}, {}, {}, {});
|
||||
|
||||
shape_inference::ShapeHandle in0 = c.input(0);
|
||||
shape_inference::ShapeHandle s1;
|
||||
|
||||
// WithRankAtMost and WithRankAtLeast on a shape with unknown dimensionality
|
||||
// always succeed.
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
|
||||
status);
|
||||
EXPECT_EQ("?", c.DebugString(s1));
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
|
||||
status);
|
||||
EXPECT_EQ("?", c.DebugString(s1));
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(OpsTest, ShapeInferenceConcatenateShapes) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
|
||||
{S({1, 2}), S({3, 4})}, {}, {}, {});
|
||||
ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
|
||||
shape_inference::ShapeHandle a = c.input(0);
|
||||
shape_inference::ShapeHandle b = c.input(1);
|
||||
TF_ShapeHandle* result = TF_NewShapeHandle();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_ShapeInferenceContextConcatenateShapes(C_CTX(&c), C_SHP(&a), C_SHP(&b),
|
||||
result, status);
|
||||
EXPECT_EQ(
|
||||
"[1,2,3,4]",
|
||||
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(result)));
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_DeleteShapeHandle(result);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
|
||||
TEST(OpsTest, DimensionHandleValueKnown) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(2, 0),
|
||||
{S({1, 2}), S({3, 4})}, {}, {}, {});
|
||||
TF_ShapeHandle* handle =
|
||||
TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
|
||||
ASSERT_EQ(
|
||||
"[43]",
|
||||
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
|
||||
ASSERT_EQ(1, TF_ShapeInferenceContextRankKnown(C_CTX(&c), handle));
|
||||
ASSERT_EQ(1, TF_ShapeInferenceContextRank(C_CTX(&c), handle));
|
||||
|
||||
TF_DimensionHandle* dim_handle = TF_NewDimensionHandle();
|
||||
TF_ShapeInferenceContextDim(C_CTX(&c), handle, 0, dim_handle);
|
||||
ASSERT_EQ(1, TF_DimensionHandleValueKnown(dim_handle));
|
||||
ASSERT_EQ(43, TF_DimensionHandleValue(dim_handle));
|
||||
TF_DeleteShapeHandle(handle);
|
||||
TF_DeleteDimensionHandle(dim_handle);
|
||||
}
|
||||
|
||||
TEST(OpsTest, ShapeInferenceSubshape) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, &def, MakeOpDef(1, 0),
|
||||
{S({10, 20, 30, 40, 50})}, {}, {}, {});
|
||||
ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));
|
||||
|
||||
TF_ShapeHandle* handle = TF_NewShapeHandle();
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_ShapeInferenceContextGetInput(C_CTX(&c), 0, handle, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
||||
TF_ShapeInferenceContextSubshape(C_CTX(&c), handle, 1, -1, handle, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status));
|
||||
ASSERT_EQ(
|
||||
"[20,30,40]",
|
||||
c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
|
||||
TF_DeleteStatus(status);
|
||||
TF_DeleteShapeHandle(handle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user