Refactor AbstractionFunction implementation to make it abstract on the C++ side as well
This is an interesting exercise as it is surfacing some coupling between eager and graph. This isn't surprising as this coupling existing in the existing runtime, but as we're trying to untangle every pieces such refactoring is forcing to expose the coupling by introducing a new API to extract a TF_Function out of the AbstractFunction. PiperOrigin-RevId: 308423040 Change-Id: I49f4fed4cb6d1c82d4b7af33cb9d5eac7edcf703
This commit is contained in:
parent
4e2e046484
commit
81d7cee332
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
|
||||
using tensorflow::string;
|
||||
using tensorflow::internal::AbstractFunction;
|
||||
using tensorflow::internal::AbstractOp;
|
||||
using tensorflow::internal::AbstractTensor;
|
||||
using tensorflow::internal::dynamic_cast_helper;
|
||||
@ -147,10 +148,18 @@ class TF_EagerOp : public AbstractOp {
|
||||
TFE_Context* ctx_;
|
||||
};
|
||||
|
||||
struct TF_AbstractFunction {
|
||||
struct GraphFunction : public AbstractFunction {
|
||||
TF_Function* func = nullptr;
|
||||
GraphFunction() : AbstractFunction(kKind) {}
|
||||
explicit GraphFunction(TF_Function* func)
|
||||
: AbstractFunction(kKind), func(func) {}
|
||||
~GraphFunction() override {
|
||||
if (func) TF_DeleteFunction(func);
|
||||
}
|
||||
|
||||
~TF_AbstractFunction() { TF_DeleteFunction(func); }
|
||||
TF_Function* GetTfFunction(TF_Status* s) override { return func; }
|
||||
|
||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||
};
|
||||
|
||||
class TF_EagerContext : public ExecutionContext {
|
||||
@ -207,8 +216,12 @@ class TF_EagerContext : public ExecutionContext {
|
||||
}
|
||||
}
|
||||
|
||||
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
|
||||
TFE_ContextAddFunction(eager_ctx_, func->func, s);
|
||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||
auto* func = afunc->GetTfFunction(s);
|
||||
if (!func) {
|
||||
return;
|
||||
}
|
||||
TFE_ContextAddFunction(eager_ctx_, func, s);
|
||||
}
|
||||
|
||||
~TF_EagerContext() override { TFE_DeleteContext(eager_ctx_); }
|
||||
@ -291,7 +304,7 @@ class TF_GraphContext : public ExecutionContext {
|
||||
nullptr, nullptr, fn_name, status);
|
||||
}
|
||||
|
||||
void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) override {
|
||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||
"Registering graph functions has not been implemented yet.");
|
||||
}
|
||||
@ -369,18 +382,20 @@ TF_AbstractFunction* TF_ExecutionContextToFunction(
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
||||
return nullptr;
|
||||
}
|
||||
TF_AbstractFunction* func = new TF_AbstractFunction;
|
||||
GraphFunction* func = new GraphFunction;
|
||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
||||
num_outputs, graph_outputs, status);
|
||||
return func;
|
||||
return wrap(func);
|
||||
}
|
||||
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { delete func; }
|
||||
void TF_DeleteAbstractFunction(TF_AbstractFunction* func) {
|
||||
delete unwrap(func);
|
||||
}
|
||||
|
||||
void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx,
|
||||
TF_AbstractFunction* func,
|
||||
TF_Status* s) {
|
||||
unwrap(ctx)->RegisterFunction(func, s);
|
||||
unwrap(ctx)->RegisterFunction(unwrap(func), s);
|
||||
}
|
||||
|
||||
TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t,
|
||||
|
@ -43,6 +43,19 @@ struct OutputList {
|
||||
int expected_num_outputs = -1;
|
||||
};
|
||||
|
||||
struct AbstractFunction {
|
||||
enum AbstractFunctionKind { kGraphFunc };
|
||||
explicit AbstractFunction(AbstractFunctionKind kind) : k(kind) {}
|
||||
AbstractFunctionKind getKind() const { return k; }
|
||||
virtual ~AbstractFunction() = default;
|
||||
|
||||
// Temporary API till we figure the right abstraction for AbstractFunction
|
||||
virtual TF_Function* GetTfFunction(TF_Status* s) = 0;
|
||||
|
||||
private:
|
||||
const AbstractFunctionKind k;
|
||||
};
|
||||
|
||||
struct AbstractOp {
|
||||
// Needed to implement our own version of RTTI since dynamic_cast is not
|
||||
// supported in mobile builds.
|
||||
@ -70,7 +83,7 @@ struct ExecutionContext {
|
||||
AbstractTensor* const* inputs, OutputList* o,
|
||||
TF_Status* s) = 0;
|
||||
virtual AbstractOp* CreateOperation() = 0;
|
||||
virtual void RegisterFunction(TF_AbstractFunction* func, TF_Status* s) = 0;
|
||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||
virtual ~ExecutionContext() = default;
|
||||
|
||||
private:
|
||||
@ -94,6 +107,7 @@ struct ExecutionContext {
|
||||
}
|
||||
|
||||
MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor)
|
||||
MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp)
|
||||
MAKE_WRAP_UNWRAP(TF_OutputList, OutputList)
|
||||
|
Loading…
Reference in New Issue
Block a user