266 lines
11 KiB
C++
266 lines
11 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/c/ops.h"
|
|
|
|
#include "tensorflow/c/tf_status_helper.h"
|
|
#include "tensorflow/core/framework/common_shape_fns.h"
|
|
#include "tensorflow/core/framework/op.h"
|
|
#include "tensorflow/core/framework/op_def_builder.h"
|
|
#include "tensorflow/core/framework/shape_inference.h"
|
|
|
|
using ::tensorflow::DataType;
|
|
using ::tensorflow::OpDef;
|
|
using ::tensorflow::OpDefBuilder;
|
|
using ::tensorflow::OpDeprecation;
|
|
using ::tensorflow::OpShapeInferenceFn;
|
|
using ::tensorflow::Set_TF_Status_from_Status;
|
|
using ::tensorflow::Status;
|
|
using ::tensorflow::shape_inference::DimensionHandle;
|
|
using ::tensorflow::shape_inference::InferenceContext;
|
|
using ::tensorflow::shape_inference::ShapeHandle;
|
|
|
|
TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
|
|
auto* result = new OpDefBuilder(op_name);
|
|
return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
|
|
}
|
|
|
|
void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
|
|
delete reinterpret_cast<OpDefBuilder*>(builder);
|
|
}
|
|
|
|
void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
|
|
const char* input_spec) {
|
|
reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
|
|
}
|
|
|
|
void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
|
|
const char* output_spec) {
|
|
reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
|
|
}
|
|
|
|
#define DEFINE_BUILDER_BOOL_SETTER(func_name) \
|
|
void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
|
|
bool arg_name) { \
|
|
reinterpret_cast<OpDefBuilder*>(builder)->func_name(); \
|
|
}
|
|
|
|
DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
|
|
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
|
|
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
|
|
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)
|
|
|
|
void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
|
|
const char* attr_spec) {
|
|
reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
|
|
}
|
|
|
|
void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
|
|
int version, const char* explanation) {
|
|
reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
|
|
}
|
|
|
|
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
|
|
TF_Status* status) {
|
|
auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
|
|
TF_SetStatus(status, TF_OK, "");
|
|
::tensorflow::OpRegistry::Global()->Register(
|
|
[cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
|
|
Status result = cc_builder->Finalize(op_reg_data);
|
|
delete cc_builder;
|
|
return result;
|
|
});
|
|
}
|
|
|
|
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
|
|
TF_OpDefinitionBuilder* builder,
|
|
void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
|
|
TF_Status* status)) {
|
|
auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
|
|
cc_builder->SetShapeFn(
|
|
[shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
|
|
TF_Status* c_status = TF_NewStatus();
|
|
auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
|
|
shape_inference_func(c_ctx, c_status);
|
|
tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
|
|
TF_DeleteStatus(c_status);
|
|
return result;
|
|
});
|
|
}
|
|
|
|
TF_ShapeHandle* TF_NewShapeHandle() {
|
|
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
|
|
}
|
|
|
|
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
|
|
TF_ShapeInferenceContext* ctx, size_t size) {
|
|
auto* handle = new ShapeHandle;
|
|
*handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
|
|
return reinterpret_cast<TF_ShapeHandle*>(handle);
|
|
}
|
|
|
|
void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
|
|
TF_ShapeHandle* first,
|
|
TF_ShapeHandle* second,
|
|
TF_ShapeHandle* result,
|
|
TF_Status* status) {
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
|
|
*reinterpret_cast<ShapeHandle*>(second),
|
|
reinterpret_cast<ShapeHandle*>(result));
|
|
Set_TF_Status_from_Status(status, s);
|
|
}
|
|
|
|
TF_DimensionHandle* TF_NewDimensionHandle() {
|
|
return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
|
|
}
|
|
|
|
int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
return cc_ctx->num_inputs();
|
|
}
|
|
|
|
void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
|
|
TF_ShapeHandle* handle,
|
|
TF_Status* status) {
|
|
TF_SetStatus(status, TF_OK, "");
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
if (0 < i || i >= cc_ctx->num_inputs()) {
|
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
|
|
}
|
|
if (TF_GetCode(status) == TF_OK) {
|
|
auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
|
|
*cc_result = cc_ctx->input(i);
|
|
}
|
|
}
|
|
|
|
int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
|
|
TF_ShapeHandle* handle) {
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
|
|
}
|
|
|
|
void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
|
|
TF_ShapeHandle* handle,
|
|
TF_Status* status) {
|
|
TF_SetStatus(status, TF_OK, "");
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
if (0 < i || i >= cc_ctx->num_outputs()) {
|
|
TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
|
|
}
|
|
if (TF_GetCode(status) == TF_OK) {
|
|
cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
|
|
}
|
|
}
|
|
|
|
void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
|
|
if (handle == nullptr) {
|
|
return;
|
|
}
|
|
|
|
delete reinterpret_cast<ShapeHandle*>(handle);
|
|
}
|
|
|
|
void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
|
|
if (handle == nullptr) {
|
|
return;
|
|
}
|
|
|
|
delete reinterpret_cast<DimensionHandle*>(handle);
|
|
}
|
|
|
|
#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
|
|
void TF_ShapeInferenceContext_GetAttr##func( \
|
|
TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
|
|
TF_Status* status) { \
|
|
TF_SetStatus(status, TF_OK, ""); \
|
|
cc_type v; \
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
|
|
Status s = cc_ctx->GetAttr(attr_name, &v); \
|
|
Set_TF_Status_from_Status(status, s); \
|
|
if (s.ok()) { \
|
|
*val = static_cast<c_type>(v); \
|
|
} \
|
|
}
|
|
|
|
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
|
|
|
|
#define DEFINE_RANK_FUNC(func_name) \
|
|
void TF_ShapeInferenceContext##func_name( \
|
|
TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
|
|
TF_ShapeHandle* result, TF_Status* status) { \
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx); \
|
|
auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle); \
|
|
auto* cc_result = reinterpret_cast<ShapeHandle*>(result); \
|
|
Status s = cc_ctx->func_name(*cc_handle, rank, cc_result); \
|
|
Set_TF_Status_from_Status(status, s); \
|
|
}
|
|
|
|
DEFINE_RANK_FUNC(WithRank)
|
|
DEFINE_RANK_FUNC(WithRankAtLeast)
|
|
DEFINE_RANK_FUNC(WithRankAtMost)
|
|
|
|
int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
|
|
TF_ShapeHandle* handle) {
|
|
return reinterpret_cast<InferenceContext*>(ctx)->Rank(
|
|
*reinterpret_cast<ShapeHandle*>(handle));
|
|
}
|
|
|
|
void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
|
|
TF_ShapeHandle* shape_handle, int64_t i,
|
|
TF_DimensionHandle* result) {
|
|
int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
|
|
auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
|
|
|
|
if (i < -rank || i >= rank) {
|
|
*cc_result = DimensionHandle();
|
|
return;
|
|
}
|
|
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
|
|
*cc_result = cc_ctx->Dim(*cc_shape_handle, i);
|
|
}
|
|
|
|
int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
|
|
return InferenceContext::ValueKnown(
|
|
*reinterpret_cast<DimensionHandle*>(dim_handle));
|
|
}
|
|
|
|
void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
|
|
TF_Status* status) {
|
|
Status s = ::tensorflow::shape_inference::UnknownShape(
|
|
reinterpret_cast<InferenceContext*>(ctx));
|
|
Set_TF_Status_from_Status(status, s);
|
|
}
|
|
|
|
void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
|
|
TF_ShapeHandle* shape_handle,
|
|
int64_t start, int64_t end,
|
|
TF_ShapeHandle* result,
|
|
TF_Status* status) {
|
|
TF_SetStatus(status, TF_OK, "");
|
|
auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
|
|
auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
|
|
Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
|
|
start, end, cc_result);
|
|
Set_TF_Status_from_Status(status, s);
|
|
}
|
|
|
|
int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
|
|
return InferenceContext::Value(
|
|
*reinterpret_cast<DimensionHandle*>(dim_handle));
|
|
}
|