STT-tensorflow/tensorflow/c/ops.cc
commit 3c9d8867a3 (James Ring): Don't call ProcessRegistrations in TF_RegisterOpDefinition
Before this change, TF_RegisterOpDefinition called ProcessRegistrations so that
the newly created registration would be processed immediately. Unfortunately,
this will not happen if a previously deferred registration fails, as
ProcessRegistrations will stop immediately when it sees errors. In the worst
case, the following chain of events will occur:

- op "A" is registered with the C++ API, its registration is deferred
- op "B" is registered with the C API, ProcessRegistrations is called
- op "A"'s deferred registration is invalid for some reason,
  ProcessRegistrations returns before processing B's registration
- builder for op "B" is deleted, but it was never actually used
- the user queries for op "B" using OpRegistry::LookUp, causing the deferred
  registration to be processed. However, the builder for "B" was deleted by
  TF_RegisterOpDefinition when "B" was initially registered. The registration
  process now uses the op definition builder after it was freed.

After this change, TF_RegisterOpDefinition will always place TF_OK in
*status, even if the definition is invalid (e.g. an op with the same
name is already registered). This behavior, while arguably undesirable,
is consistent with REGISTER_OP in TF C++.

PiperOrigin-RevId: 255423852
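
For context, a registration through this C API looks roughly like the minimal sketch below. The op name "ExampleIdentity", its input/output/attr specs, and the helper names ExampleShapeFn and RegisterExampleIdentity are illustrative, not part of this file; the TF_Status helpers are assumed to come from the TensorFlow C API. Note that TF_RegisterOpDefinition takes ownership of the builder and, after this change, always reports TF_OK even if the deferred registration later fails.

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/ops.h"

// Shape function: propagate the shape of input 0 to output 0.
static void ExampleShapeFn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
  TF_ShapeHandle* shape = TF_NewShapeHandle();
  TF_ShapeInferenceContextGetInput(ctx, 0, shape, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_ShapeInferenceContextSetOutput(ctx, 0, shape, status);
  }
  TF_DeleteShapeHandle(shape);
}

void RegisterExampleIdentity() {
  TF_OpDefinitionBuilder* builder =
      TF_NewOpDefinitionBuilder("ExampleIdentity");
  TF_OpDefinitionBuilderAddInput(builder, "input: T");
  TF_OpDefinitionBuilderAddOutput(builder, "output: T");
  TF_OpDefinitionBuilderAddAttr(builder, "T: type");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &ExampleShapeFn);

  TF_Status* status = TF_NewStatus();
  // Takes ownership of `builder`; do not use or delete it afterwards.
  TF_RegisterOpDefinition(builder, status);
  TF_DeleteStatus(status);
}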

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference.h"

using ::tensorflow::DataType;
using ::tensorflow::OpDef;
using ::tensorflow::OpDefBuilder;
using ::tensorflow::OpDeprecation;
using ::tensorflow::OpShapeInferenceFn;
using ::tensorflow::Set_TF_Status_from_Status;
using ::tensorflow::Status;
using ::tensorflow::shape_inference::DimensionHandle;
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeHandle;

TF_OpDefinitionBuilder* TF_NewOpDefinitionBuilder(const char* op_name) {
  auto* result = new OpDefBuilder(op_name);
  return reinterpret_cast<TF_OpDefinitionBuilder*>(result);
}

void TF_DeleteOpDefinitionBuilder(TF_OpDefinitionBuilder* builder) {
  delete reinterpret_cast<OpDefBuilder*>(builder);
}

void TF_OpDefinitionBuilderAddInput(TF_OpDefinitionBuilder* builder,
                                    const char* input_spec) {
  reinterpret_cast<OpDefBuilder*>(builder)->Input(input_spec);
}

void TF_OpDefinitionBuilderAddOutput(TF_OpDefinitionBuilder* builder,
                                     const char* output_spec) {
  reinterpret_cast<OpDefBuilder*>(builder)->Output(output_spec);
}
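
// Generated setters for the boolean flags on OpDefBuilder. The underlying
// OpDefBuilder setters take no argument, so the bool parameter is unused here
// and calling any of these simply enables the corresponding flag.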
#define DEFINE_BUILDER_BOOL_SETTER(func_name)                             \
  void TF_OpDefinitionBuilder##func_name(TF_OpDefinitionBuilder* builder, \
                                         bool arg_name) {                 \
    reinterpret_cast<OpDefBuilder*>(builder)->func_name();                \
  }

DEFINE_BUILDER_BOOL_SETTER(SetIsCommutative)
DEFINE_BUILDER_BOOL_SETTER(SetIsAggregate)
DEFINE_BUILDER_BOOL_SETTER(SetIsStateful)
DEFINE_BUILDER_BOOL_SETTER(SetAllowsUninitializedInput)

void TF_OpDefinitionBuilderAddAttr(TF_OpDefinitionBuilder* builder,
                                   const char* attr_spec) {
  reinterpret_cast<OpDefBuilder*>(builder)->Attr(attr_spec);
}

void TF_OpDefinitionBuilderDeprecated(TF_OpDefinitionBuilder* builder,
                                      int version, const char* explanation) {
  reinterpret_cast<OpDefBuilder*>(builder)->Deprecated(version, explanation);
}
void TF_RegisterOpDefinition(TF_OpDefinitionBuilder* builder,
                             TF_Status* status) {
  auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
  TF_SetStatus(status, TF_OK, "");
  ::tensorflow::OpRegistry::Global()->Register(
      [cc_builder](::tensorflow::OpRegistrationData* op_reg_data) -> Status {
        Status result = cc_builder->Finalize(op_reg_data);
        delete cc_builder;
        return result;
      });
}
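
// Adapts the C shape inference callback to the C++ shape function: the
// InferenceContext is handed to the callback as an opaque
// TF_ShapeInferenceContext*, and the TF_Status it fills in is converted back
// into a tensorflow::Status.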
void TF_OpDefinitionBuilderSetShapeInferenceFunction(
    TF_OpDefinitionBuilder* builder,
    void (*shape_inference_func)(TF_ShapeInferenceContext* ctx,
                                 TF_Status* status)) {
  auto* cc_builder = reinterpret_cast<OpDefBuilder*>(builder);
  cc_builder->SetShapeFn(
      [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status {
        TF_Status* c_status = TF_NewStatus();
        auto c_ctx = reinterpret_cast<TF_ShapeInferenceContext*>(ctx);
        shape_inference_func(c_ctx, c_status);
        tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status);
        TF_DeleteStatus(c_status);
        return result;
      });
}

TF_ShapeHandle* TF_NewShapeHandle() {
  return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}

TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
    TF_ShapeInferenceContext* ctx, size_t size) {
  auto* handle = new ShapeHandle;
  *handle = reinterpret_cast<InferenceContext*>(ctx)->Vector(size);
  return reinterpret_cast<TF_ShapeHandle*>(handle);
}

void TF_ShapeInferenceContextConcatenateShapes(TF_ShapeInferenceContext* ctx,
                                               TF_ShapeHandle* first,
                                               TF_ShapeHandle* second,
                                               TF_ShapeHandle* result,
                                               TF_Status* status) {
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  Status s = cc_ctx->Concatenate(*reinterpret_cast<ShapeHandle*>(first),
                                 *reinterpret_cast<ShapeHandle*>(second),
                                 reinterpret_cast<ShapeHandle*>(result));
  Set_TF_Status_from_Status(status, s);
}

TF_DimensionHandle* TF_NewDimensionHandle() {
  return reinterpret_cast<TF_DimensionHandle*>(new DimensionHandle);
}

int64_t TF_ShapeInferenceContextNumInputs(TF_ShapeInferenceContext* ctx) {
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  return cc_ctx->num_inputs();
}

void TF_ShapeInferenceContextGetInput(TF_ShapeInferenceContext* ctx, int i,
                                      TF_ShapeHandle* handle,
                                      TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  // Validate the input index before reading from the context.
  if (i < 0 || i >= cc_ctx->num_inputs()) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT, "input index out of range");
  }
  if (TF_GetCode(status) == TF_OK) {
    auto* cc_result = reinterpret_cast<ShapeHandle*>(handle);
    *cc_result = cc_ctx->input(i);
  }
}

int TF_ShapeInferenceContextRankKnown(TF_ShapeInferenceContext* ctx,
                                      TF_ShapeHandle* handle) {
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  return cc_ctx->RankKnown(*reinterpret_cast<ShapeHandle*>(handle));
}

void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i,
                                       TF_ShapeHandle* handle,
                                       TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  // Validate the output index before writing into the context.
  if (i < 0 || i >= cc_ctx->num_outputs()) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT, "output index out of range");
  }
  if (TF_GetCode(status) == TF_OK) {
    cc_ctx->set_output(i, *(reinterpret_cast<ShapeHandle*>(handle)));
  }
}

void TF_DeleteShapeHandle(TF_ShapeHandle* handle) {
  if (handle == nullptr) {
    return;
  }
  delete reinterpret_cast<ShapeHandle*>(handle);
}

void TF_DeleteDimensionHandle(TF_DimensionHandle* handle) {
  if (handle == nullptr) {
    return;
  }
  delete reinterpret_cast<DimensionHandle*>(handle);
}
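
// Generates TF_ShapeInferenceContext_GetAttr<func> wrappers that read an attr
// from the InferenceContext and cast it to the corresponding C type. Only the
// DataType/TF_DataType variant is instantiated below.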
#define DEFINE_TF_GETATTR(func, c_type, cc_type)                         \
  void TF_ShapeInferenceContext_GetAttr##func(                           \
      TF_ShapeInferenceContext* ctx, const char* attr_name, c_type* val, \
      TF_Status* status) {                                               \
    TF_SetStatus(status, TF_OK, "");                                     \
    cc_type v;                                                           \
    auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);             \
    Status s = cc_ctx->GetAttr(attr_name, &v);                           \
    Set_TF_Status_from_Status(status, s);                                \
    if (s.ok()) {                                                        \
      *val = static_cast<c_type>(v);                                     \
    }                                                                    \
  }

DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
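
// Generates thin wrappers around InferenceContext::WithRank, WithRankAtLeast
// and WithRankAtMost, reporting the resulting status through TF_Status.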
#define DEFINE_RANK_FUNC(func_name)                                        \
  void TF_ShapeInferenceContext##func_name(                                \
      TF_ShapeInferenceContext* ctx, TF_ShapeHandle* handle, int64_t rank, \
      TF_ShapeHandle* result, TF_Status* status) {                         \
    auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);               \
    auto* cc_handle = reinterpret_cast<ShapeHandle*>(handle);              \
    auto* cc_result = reinterpret_cast<ShapeHandle*>(result);              \
    Status s = cc_ctx->func_name(*cc_handle, rank, cc_result);             \
    Set_TF_Status_from_Status(status, s);                                  \
  }

DEFINE_RANK_FUNC(WithRank)
DEFINE_RANK_FUNC(WithRankAtLeast)
DEFINE_RANK_FUNC(WithRankAtMost)

int64_t TF_ShapeInferenceContextRank(TF_ShapeInferenceContext* ctx,
                                     TF_ShapeHandle* handle) {
  return reinterpret_cast<InferenceContext*>(ctx)->Rank(
      *reinterpret_cast<ShapeHandle*>(handle));
}
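
// Looks up dimension `i` of `shape_handle`. Negative indices count back from
// the end of the shape; indices outside [-rank, rank) produce a default
// DimensionHandle rather than an error.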
void TF_ShapeInferenceContextDim(TF_ShapeInferenceContext* ctx,
                                 TF_ShapeHandle* shape_handle, int64_t i,
                                 TF_DimensionHandle* result) {
  int64_t rank = TF_ShapeInferenceContextRank(ctx, shape_handle);
  auto* cc_result = reinterpret_cast<DimensionHandle*>(result);
  if (i < -rank || i >= rank) {
    *cc_result = DimensionHandle();
    return;
  }
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  auto* cc_shape_handle = reinterpret_cast<ShapeHandle*>(shape_handle);
  *cc_result = cc_ctx->Dim(*cc_shape_handle, i);
}

int TF_DimensionHandleValueKnown(TF_DimensionHandle* dim_handle) {
  return InferenceContext::ValueKnown(
      *reinterpret_cast<DimensionHandle*>(dim_handle));
}

void TF_ShapeInferenceContextSetUnknownShape(TF_ShapeInferenceContext* ctx,
                                             TF_Status* status) {
  Status s = ::tensorflow::shape_inference::UnknownShape(
      reinterpret_cast<InferenceContext*>(ctx));
  Set_TF_Status_from_Status(status, s);
}

void TF_ShapeInferenceContextSubshape(TF_ShapeInferenceContext* ctx,
                                      TF_ShapeHandle* shape_handle,
                                      int64_t start, int64_t end,
                                      TF_ShapeHandle* result,
                                      TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
  auto* cc_ctx = reinterpret_cast<InferenceContext*>(ctx);
  auto* cc_result = reinterpret_cast<ShapeHandle*>(result);
  Status s = cc_ctx->Subshape(*reinterpret_cast<ShapeHandle*>(shape_handle),
                              start, end, cc_result);
  Set_TF_Status_from_Status(status, s);
}

int64_t TF_DimensionHandleValue(TF_DimensionHandle* dim_handle) {
  return InferenceContext::Value(
      *reinterpret_cast<DimensionHandle*>(dim_handle));
}