Refactor AbstractionFunction implementation to make it abstract on the C++ side as well

This is an interesting exercise as it is surfacing some coupling between eager
and graph. This isn't surprising as this coupling existing in the existing runtime,
but as we're trying to untangle every pieces such refactoring is forcing to expose
the coupling by introducing a new API to extract a TF_Function out of the
AbstractFunction.

PiperOrigin-RevId: 308423040
Change-Id: I49f4fed4cb6d1c82d4b7af33cb9d5eac7edcf703
This commit is contained in:
Mehdi Amini 2020-04-25 10:07:35 -07:00 committed by TensorFlower Gardener
parent 4e2e046484
commit 81d7cee332
2 changed files with 39 additions and 10 deletions

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/strcat.h"
using tensorflow::string;
using tensorflow::internal::AbstractFunction;
using tensorflow::internal::AbstractOp;
using tensorflow::internal::AbstractTensor;
using tensorflow::internal::dynamic_cast_helper;
@ -147,10 +148,18 @@ class TF_EagerOp : public AbstractOp {
TFE_Context* ctx_;
};
struct TF_AbstractFunction {
struct GraphFunction : public AbstractFunction {
TF_Function* func = nullptr;
GraphFunction() : AbstractFunction(kKind) {}
explicit GraphFunction(TF_Function* func)
: AbstractFunction(kKind), func(func) {}
~GraphFunction() override {
if (func) TF_DeleteFunction(func);
}
~TF_AbstractFunction() { TF_DeleteFunction(func); }
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
static constexpr AbstractFunctionKind kKind = kGraphFunc;
};
class TF_EagerContext : public ExecutionContext {
@ -207,8 +216,12 @@ class TF_EagerContext : public ExecutionContext {
}
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
TFE_ContextAddFunction(eager_ctx_, func->func, s);
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
auto* func = afunc->GetTfFunction(s);
if (!func) {
return;
}
TFE_ContextAddFunction(eager_ctx_, func, s);
}
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
@ -291,7 +304,7 @@ class TF_GraphContext : public ExecutionContext {
nullptr, nullptr, fn_name, status);
}
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
TF_SetStatus(s, TF_UNIMPLEMENTED,
"Registering graph functions has not been implemented yet.");
}
@ -369,18 +382,20 @@ TF_AbstractFunction* TF_ExecutionContextToFunction(
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
return nullptr;
}
TF_AbstractFunction* func = new TF_AbstractFunction;
GraphFunction* func = new GraphFunction;
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
num_outputs, graph_outputs, status);
return func;
return wrap(func);
}
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
delete unwrap(func);
}
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
TF_AbstractFunction* func,
TF_Status* s) {
unwrap(ctx)->RegisterFunction(func, s);
unwrap(ctx)->RegisterFunction(unwrap(func), s);
}
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,

View File

@ -43,6 +43,19 @@ struct OutputList {
int expected_num_outputs = -1;
};
struct AbstractFunction {
enum AbstractFunctionKind { kGraphFunc };
explicit AbstractFunction(AbstractFunctionKind kind) : k(kind) {}
AbstractFunctionKind getKind() const { return k; }
virtual ~AbstractFunction() = default;
// Temporary API till we figure the right abstraction for AbstractFunction
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
private:
const AbstractFunctionKind k;
};
struct AbstractOp {
// Needed to implement our own version of RTTI since dynamic_cast is not
// supported in mobile builds.
@ -70,7 +83,7 @@ struct ExecutionContext {
AbstractTensor* const* inputs, OutputList* o,
TF_Status* s) = 0;
virtual AbstractOp* CreateOperation() = 0;
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
virtual ~ExecutionContext() = default;
private:
@ -94,6 +107,7 @@ struct ExecutionContext {
}
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)