diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index da7ddc3ec06..9696a3415bf 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -158,9 +158,13 @@ cc_library( "//tensorflow:internal", ], deps = [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", ":c_api", ":c_api_experimental", "//tensorflow/c:c_api_internal", + "//tensorflow/c:conversion_macros", "//tensorflow/c:tf_status", "//tensorflow/core/platform:casts", "//tensorflow/core/platform:types", @@ -541,6 +545,9 @@ tf_cuda_library( ":abstract_operation", ":abstract_context", ":abstract_tensor_handle", + ":immediate_execution_tensor_handle", + ":immediate_execution_context", + "//tensorflow/core/lib/llvm_rtti", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -559,6 +566,7 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/types:variant", + "//tensorflow/c:conversion_macros", ], }) + select({ "//tensorflow:with_xla_support": [ @@ -732,6 +740,10 @@ filegroup( ], exclude = [ "c_api_experimental.cc", + "c_api_unified_experimental.cc", + "c_api_unified_experimental_eager.cc", + "c_api_unified_experimental_graph.cc", + "c_api_unified_experimental_internal.h", "*test*", "*dlpack*", ], diff --git a/tensorflow/c/eager/abstract_function.h b/tensorflow/c/eager/abstract_function.h index 303dd435c05..f02bc97b28c 100644 --- a/tensorflow/c/eager/abstract_function.h +++ b/tensorflow/c/eager/abstract_function.h @@ -25,7 +25,7 @@ namespace tensorflow { // function. class AbstractFunction { protected: - enum AbstractFunctionKind { kGraphFunc, kMlirFunc }; + enum AbstractFunctionKind { kGraph, kMlir }; explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {} public: diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index e5030a602b3..605a60c186c 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -22,15 +22,17 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" using tensorflow::string; -using tensorflow::internal::OutputList; -using tensorflow::internal::unwrap; namespace tensorflow { -namespace internal { -typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap; +namespace tracing { +typedef absl::flat_hash_map<std::string, tracing::FactoryFunction> FactoriesMap; static FactoriesMap& GetFactories() { static FactoriesMap* factories = new FactoriesMap; @@ -48,8 +50,8 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { void SetDefaultTracingEngine(const char* name) { default_factory = name; } -static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, - TF_Status* s) { +static TracingContext* CreateTracingExecutionContext(const char* fn_name, + TF_Status* s) { auto entry = GetFactories().find(default_factory); if (entry != GetFactories().end()) return entry->second(fn_name, s); string msg = absl::StrCat( @@ -70,7 +72,7 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, return nullptr; } -} // end namespace internal +} // end namespace tracing } // end namespace tensorflow // ============================================================================= @@ -83,43 +85,77 @@ static ExecutionContext* CreateTracingExecutionContext(const char* fn_name, // // ============================================================================= +using tensorflow::AbstractFunction; +using tensorflow::AbstractTensorHandle; +using tensorflow::DataType; +using tensorflow::dyn_cast; +using tensorflow::OutputList; +using tensorflow::Status; +using tensorflow::unwrap; +using tensorflow::wrap; +using tensorflow::tracing::CreateTracingExecutionContext; +using tensorflow::tracing::SetDefaultTracingEngine; +using tensorflow::tracing::TracingContext; +using tensorflow::tracing::TracingOperation; +using tensorflow::tracing::TracingTensorHandle; + void TF_SetTracingImplementation(const char* name) { - tensorflow::internal::SetDefaultTracingEngine(name); + SetDefaultTracingEngine(name); } // Creates a new TensorFlow function, it is an execution context attached to a // given tracing context. TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) { - return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s)); + return wrap(CreateTracingExecutionContext(fn_name, s)); } TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, TF_OutputList* outputs, TF_Status* s) { - auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s)); + AbstractFunction* func; + TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(ctx)); + if (!tracing_ctx) { + Set_TF_Status_from_Status( + s, tensorflow::errors::InvalidArgument( + "Only TracingContext can be converted into a function.")); + return nullptr; + } + Set_TF_Status_from_Status(s, tracing_ctx->Finalize(unwrap(outputs), &func)); TF_DeleteExecutionContext(ctx); - return func; + return wrap(func); } TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, TF_DataType dtype, TF_Status* s) { - return wrap(unwrap(func)->AddParameter(dtype, s)); + TracingTensorHandle* t; + TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func)); + if (!tracing_ctx) { + Set_TF_Status_from_Status( + s, tensorflow::errors::InvalidArgument( + "TF_AddFunctionParameter must be called on a TracingContext.")); + return nullptr; + } + Set_TF_Status_from_Status( + s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t)); + return wrap(t); } -void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); } +void TF_DeleteExecutionContext(TF_ExecutionContext* c) { unwrap(c)->Release(); } TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { - return wrap(unwrap(c)->CreateOperation()); + return wrap((unwrap(c)->CreateOperation())); } -void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete unwrap(op); } +void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); } -void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete unwrap(t); } +void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); } TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs, TF_Status* s) { unwrap(o)->expected_num_outputs = num_outputs; + unwrap(o)->outputs.clear(); + unwrap(o)->outputs.resize(num_outputs); } int TF_OutputListNumOutputs(TF_OutputList* o) { return unwrap(o)->outputs.size(); @@ -134,24 +170,46 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type, TF_Status* s) { - unwrap(op)->SetOpType(op_type, s); + Set_TF_Status_from_Status(s, unwrap(op)->Reset(op_type, + /*raw_device_name=*/nullptr)); } void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name, TF_Status* s) { - unwrap(op)->SetOpName(op_name, s); + TracingOperation* tracing_op = dyn_cast<TracingOperation>(unwrap(op)); + if (!tracing_op) { + Set_TF_Status_from_Status( + s, tensorflow::errors::InvalidArgument( + "TF_AbstractOpSetOpName must be called on a TracingOperation.")); + return; + } + Set_TF_Status_from_Status(s, tracing_op->SetOpName(op_name)); } void TF_AbstractOpSetAttrType(TF_AbstractOp* op, const char* const attr_name, TF_DataType value, TF_Status* s) { - unwrap(op)->SetAttrType(attr_name, value, s); + Status status = + unwrap(op)->SetAttrType(attr_name, static_cast<DataType>(value)); + TF_SetStatus(s, static_cast<TF_Code>(status.code()), + status.error_message().c_str()); } void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_AbstractTensor* const* inputs, TF_OutputList* o, - TF_ExecutionContext* ctx, TF_Status* s) { - unwrap(ctx)->ExecuteOperation(unwrap(op), num_inputs, &unwrap(*inputs), - unwrap(o), s); + TF_Status* s) { + for (int i = 0; i < num_inputs; i++) { + Set_TF_Status_from_Status(s, unwrap(op)->AddInput(unwrap(inputs[i]))); + if (TF_GetCode(s) != TF_OK) { + return; + } + } + int num_outputs = unwrap(o)->expected_num_outputs; + Set_TF_Status_from_Status( + s, unwrap(op)->Execute( + absl::MakeSpan(reinterpret_cast<AbstractTensorHandle**>( + unwrap(o)->outputs.data()), + unwrap(o)->outputs.size()), + &num_outputs)); } void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { @@ -161,5 +219,5 @@ void TF_DeleteAbstractFunction(TF_AbstractFunction* func) { void TF_ExecutionContextRegisterFunction(TF_ExecutionContext* ctx, TF_AbstractFunction* func, TF_Status* s) { - unwrap(ctx)->RegisterFunction(unwrap(func), s); + Set_TF_Status_from_Status(s, unwrap(ctx)->RegisterFunction(unwrap(func))); } diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index 86c59a7f625..b66869b4290 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -110,7 +110,7 @@ void TF_OutputListPushBack(TF_OutputList* o, TF_AbstractTensor* tensor, // Any active tape will observe the effects of this execution. void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs, TF_AbstractTensor* const* inputs, TF_OutputList* o, - TF_ExecutionContext* ctx, TF_Status* s); + TF_Status* s); // Creates a new TF_AbstractFunction from the current tracing states in the // context. The provided `ctx` is consumed by this API call and deleted. @@ -137,7 +137,8 @@ TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t, TF_Status* s); TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, TF_Status* s); -TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*); +TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext*, + TF_Status* s); #ifdef __cplusplus } /* end extern "C" */ diff --git a/tensorflow/c/eager/c_api_unified_experimental_eager.cc b/tensorflow/c/eager/c_api_unified_experimental_eager.cc index cf8cf845834..986b48ff8f2 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_eager.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_eager.cc @@ -15,180 +15,68 @@ limitations under the License. #include <vector> +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/tf_datatype.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/eager/tfe_context_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/tf_status.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/strcat.h" -#include "tensorflow/core/platform/types.h" - -using tensorflow::string; - -namespace tensorflow { -namespace internal { - -// Simple wrapper over a TFE_TensorHandle -struct EagerTensor : public AbstractTensor { - TFE_TensorHandle* t = nullptr; - EagerTensor() : AbstractTensor(kKind) {} - explicit EagerTensor(TFE_TensorHandle* t) : AbstractTensor(kKind), t(t) {} - ~EagerTensor() override { TFE_DeleteTensorHandle(t); } - static constexpr AbstractTensorKind kKind = kEagerTensor; -}; - -// Simple wrapper over a TFE_Op -class EagerOp : public AbstractOp { - public: - explicit EagerOp(TFE_Context* ctx) : AbstractOp(kKind), ctx_(ctx) {} - void SetOpType(const char* const op_type, TF_Status* s) override { - op_ = TFE_NewOp(ctx_, op_type, s); - } - void SetOpName(const char* const op_name, TF_Status* s) override { - // Name is ignored in eager mode. - } - void SetAttrType(const char* const attr_name, TF_DataType value, - TF_Status* s) override { - if (op_ == nullptr) { - TF_SetStatus(s, TF_FAILED_PRECONDITION, - "op_type must be specified before specifying attrs."); - return; - } - TFE_OpSetAttrType(op_, attr_name, value); - } - - ~EagerOp() override { TFE_DeleteOp(op_); } - static constexpr AbstractOpKind kKind = kEagerOp; - - private: - friend class EagerContext; // For access to op_. - TFE_Op* op_ = nullptr; - TFE_Context* ctx_; -}; - -// Wraps a TFE_Context and dispatch EagerOp with EagerTensor inputs. -class EagerContext : public ExecutionContext { - public: - EagerContext() : ExecutionContext(kKind) {} - - void Build(TFE_ContextOptions* options, TF_Status* status) { - eager_ctx_ = TFE_NewContext(options, status); - } - - AbstractOp* CreateOperation() override { - // TODO(srbs): Should the lifetime of this op be tied to the context. - return new EagerOp(eager_ctx_); - } - - void ExecuteOperation(AbstractOp* op, int num_inputs, - AbstractTensor* const* inputs, OutputList* o, - TF_Status* s) override { - auto* eager_op = dyncast<EagerOp>(op); - if (eager_op == nullptr) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, - "Unable to cast AbstractOp to TF_EagerOp."); - return; - } - auto* tfe_op = eager_op->op_; - if (TF_GetCode(s) != TF_OK) return; - for (int i = 0; i < num_inputs; ++i) { - auto* eager_tensor = dyncast<const EagerTensor>(inputs[i]); - if (!eager_tensor) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor."); - return; - } - TFE_OpAddInput(tfe_op, eager_tensor->t, s); - if (TF_GetCode(s) != TF_OK) return; - } - if (o->expected_num_outputs == -1) { - string msg = - "The number of outputs must be provided in eager mode. Use " - "TF_OutputListSetNumOutputs."; - TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); - return; - } - tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals; - int num_retvals = o->expected_num_outputs; - retvals.resize(num_retvals); - TFE_Execute(tfe_op, retvals.data(), &num_retvals, s); - if (TF_GetCode(s) != TF_OK) { - return; - } - o->outputs.clear(); - o->outputs.reserve(num_retvals); - for (int i = 0; i < num_retvals; ++i) { - o->outputs.push_back(new EagerTensor(retvals[i])); - } - } - - AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { - TF_SetStatus(s, TF_INVALID_ARGUMENT, - "Can't add function parameter on an eager context."); - return nullptr; - } - AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { - TF_SetStatus(s, TF_INVALID_ARGUMENT, - "Can't use finalize function on an eager context."); - return nullptr; - } - - void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override { - auto* func = afunc->GetTfFunction(s); - if (!func) { - return; - } - TFE_ContextAddFunction(eager_ctx_, func, s); - } - - ~EagerContext() override { TFE_DeleteContext(eager_ctx_); } - - static constexpr ExecutionContextKind kKind = kEagerContext; - - private: - friend TFE_Context* ::TF_ExecutionContextGetTFEContext( - TF_ExecutionContext* ctx); - TFE_Context* eager_ctx_; -}; - -} // namespace internal -} // namespace tensorflow // ============================================================================= // Public C API entry points // These are only the entry points specific to the Eager API. // ============================================================================= -using tensorflow::internal::dyncast; -using tensorflow::internal::unwrap; +using tensorflow::AbstractContext; +using tensorflow::AbstractTensorHandle; +using tensorflow::dyn_cast; +using tensorflow::ImmediateExecutionContext; +using tensorflow::ImmediateExecutionTensorHandle; +using tensorflow::string; +using tensorflow::unwrap; +using tensorflow::wrap; +using tensorflow::strings::StrCat; TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions* options, TF_Status* s) { - auto* ctx = new tensorflow::internal::EagerContext(); - ctx->Build(options, s); - return wrap(ctx); + TFE_Context* c_ctx = TFE_NewContext(options, s); + if (TF_GetCode(s) != TF_OK) { + return nullptr; + } + return wrap(static_cast<AbstractContext*>(unwrap(c_ctx))); } TF_AbstractTensor* TF_CreateAbstractTensorFromEagerTensor(TFE_TensorHandle* t, TF_Status* s) { - return wrap(new tensorflow::internal::EagerTensor(t)); + return wrap(static_cast<AbstractTensorHandle*>(unwrap(t))); } TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at, TF_Status* s) { - auto* eager_tensor = dyncast<tensorflow::internal::EagerTensor>(unwrap(at)); - if (!eager_tensor) { - string msg = tensorflow::strings::StrCat("Not an eager tensor handle.", - reinterpret_cast<uintptr_t>(at)); + auto handle = dyn_cast<ImmediateExecutionTensorHandle>(unwrap(at)); + if (!handle) { + string msg = + StrCat("Not an eager tensor handle.", reinterpret_cast<uintptr_t>(at)); TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); return nullptr; } - return eager_tensor->t; + return wrap(handle); } -TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx) { - auto* eager_ctx = dyncast<tensorflow::internal::EagerContext>(unwrap(ctx)); - if (!eager_ctx) return nullptr; - return eager_ctx->eager_ctx_; +TFE_Context* TF_ExecutionContextGetTFEContext(TF_ExecutionContext* ctx, + TF_Status* s) { + auto imm_ctx = dyn_cast<ImmediateExecutionContext>(unwrap(ctx)); + if (!imm_ctx) { + string msg = + StrCat("Not an eager context.", reinterpret_cast<uintptr_t>(ctx)); + TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str()); + return nullptr; + } + return wrap(imm_ctx); } diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index dd5a95b3526..bda5e163a50 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -18,77 +18,198 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/types.h" +using tensorflow::dyn_cast; using tensorflow::string; namespace tensorflow { -namespace internal { +namespace tracing { +namespace graph { class GraphContext; +class GraphOperation; +class GraphTensor; // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index // into the list of outputs for the operation. -struct GraphTensor : public AbstractTensor { - TF_Output output{}; - GraphContext* ctx = nullptr; - GraphTensor() : AbstractTensor(kKind) {} - GraphTensor(TF_Output output, GraphContext* ctx) - : AbstractTensor(kKind), output(output), ctx(ctx) {} - static constexpr AbstractTensorKind kKind = kGraphTensor; +class GraphTensor : public TracingTensorHandle { + public: + explicit GraphTensor(TF_Output output) + : TracingTensorHandle(kGraph), output_(output) {} + void Release() override { delete this; } + TF_Output output_; + + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kGraph; + } }; -// GraphOp wraps and populate a TF_OperationDescription. -class GraphOp : public AbstractOp { +// GraphOperation wraps and populates a TF_OperationDescription. +class GraphOperation : public TracingOperation { public: - explicit GraphOp(TF_Graph* g) : AbstractOp(kKind), g_(g) {} - void SetOpType(const char* const op_type, TF_Status* s) override { + explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {} + void Release() override { delete this; } + Status Reset(const char* op, const char* raw_device_name) override { if (op_) { - TF_SetStatus( - s, TF_FAILED_PRECONDITION, - strings::StrCat("SetOpType called on already built op.").c_str()); - return; + return errors::FailedPrecondition("Reset called on already built op."); } - if (op_name_ != nullptr) { - op_.reset(TF_NewOperation(g_, op_type, op_name_)); - op_name_ = nullptr; - } else { - op_type_ = op_type; + if (raw_device_name) { + device_name_ = raw_device_name; } + op_type_ = op; + return Status::OK(); } - void SetOpName(const char* const op_name, TF_Status* s) override { + Status SetOpName(const char* const op_name) override { if (op_) { - TF_SetStatus( - s, TF_FAILED_PRECONDITION, - strings::StrCat("SetOpName called on already built op.").c_str()); - return; + return errors::FailedPrecondition( + "SetOpName called on already built op."); } - if (op_type_ != nullptr) { - op_.reset(TF_NewOperation(g_, op_type_, op_name)); - op_type_ = nullptr; - } else { - op_name_ = op_name; + if (op_type_.empty()) { + return errors::FailedPrecondition( + "GraphOperation::Reset must be called before calling SetOpName."); } + op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name)); + return Status::OK(); } - void SetAttrType(const char* const attr_name, TF_DataType value, - TF_Status* s) override { - if (!op_) { - TF_SetStatus( - s, TF_FAILED_PRECONDITION, - "op_type and op_name must be specified before specifying attrs."); - return; - } - TF_SetAttrType(op_.get(), attr_name, value); - } - ~GraphOp() override {} + const string& Name() const override { return op_type_; } + const string& DeviceName() const override { return device_name_; } - static constexpr AbstractOpKind kKind = kGraphOp; + Status SetDeviceName(const char* name) override { + // TODO(srbs): Implement this. + device_name_ = name; + return Status::OK(); + } + + Status AddInput(AbstractTensorHandle* input) override { + GraphTensor* t = dyn_cast<GraphTensor>(input); + if (!t) { + return tensorflow::errors::InvalidArgument( + "Unable to cast input to GraphTensor"); + } + TF_AddInput(op_.get(), t->output_); + return Status::OK(); + } + Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override { + return tensorflow::errors::Unimplemented( + "AddInputList has not been implemented yet."); + } + Status Execute(absl::Span<AbstractTensorHandle*> retvals, + int* num_retvals) override { + auto* tf_opdesc = op_.release(); + if (tf_opdesc == nullptr) { + return errors::InvalidArgument("AbstractOp is incomplete."); + } + TF_Status* s = TF_NewStatus(); + auto* operation = TF_FinishOperation(tf_opdesc, s); + TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); + TF_DeleteStatus(s); + *num_retvals = TF_OperationNumOutputs(operation); + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = new GraphTensor({operation, i}); + } + return Status::OK(); + } + + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override { + return tensorflow::errors::Unimplemented( + "SetAttrString has not been implemented yet."); + } + Status SetAttrInt(const char* attr_name, int64_t value) override { + return tensorflow::errors::Unimplemented( + "SetAttrInt has not been implemented yet."); + } + Status SetAttrFloat(const char* attr_name, float value) override { + return tensorflow::errors::Unimplemented( + "SetAttrFloat has not been implemented yet."); + } + Status SetAttrBool(const char* attr_name, bool value) override { + return tensorflow::errors::Unimplemented( + "SetAttrBool has not been implemented yet."); + } + Status SetAttrType(const char* const attr_name, DataType value) override { + if (!op_) { + return Status( + error::Code::FAILED_PRECONDITION, + "op_type and op_name must be specified before specifying attrs."); + } + op_->node_builder.Attr(attr_name, value); + return Status::OK(); + } + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override { + return tensorflow::errors::Unimplemented( + "SetAttrShape has not been implemented yet."); + } + Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override { + return tensorflow::errors::Unimplemented( + "SetAttrFunction has not been implemented yet."); + } + Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionName has not been implemented yet."); + } + Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override { + return tensorflow::errors::Unimplemented( + "SetAttrTensor has not been implemented yet."); + } + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override { + return tensorflow::errors::Unimplemented( + "SetAttrStringList has not been implemented yet."); + } + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override { + return tensorflow::errors::Unimplemented( + "SetAttrFloatList has not been implemented yet."); + } + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override { + return tensorflow::errors::Unimplemented( + "SetAttrIntList has not been implemented yet."); + } + Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override { + return tensorflow::errors::Unimplemented( + "SetAttrTypeList has not been implemented yet."); + } + Status SetAttrBoolList(const char* attr_name, const unsigned char* values, + int num_values) override { + return tensorflow::errors::Unimplemented( + "SetAttrBoolList has not been implemented yet."); + } + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override { + return tensorflow::errors::Unimplemented( + "SetAttrShapeList has not been implemented yet."); + } + Status SetAttrFunctionList( + const char* attr_name, + absl::Span<const AbstractOperation*> values) override { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionList has not been implemented yet."); + } + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kGraph; + } + ~GraphOperation() override {} private: friend class GraphContext; // For access to op_. @@ -96,123 +217,109 @@ class GraphOp : public AbstractOp { std::unique_ptr<TF_OperationDescription> op_; // Hold `op_type` and `op_name` till both are available since we need both // to build a graph operation. - const char* op_type_ = nullptr; + string op_type_; const char* op_name_ = nullptr; + // TODO(srbs): Use this. + string device_name_; }; // GraphFunction is a thin wrapper over a TF_Function. struct GraphFunction : public AbstractFunction { TF_Function* func = nullptr; - GraphFunction() : AbstractFunction(kKind) {} + GraphFunction() : AbstractFunction(kGraph) {} explicit GraphFunction(TF_Function* func) - : AbstractFunction(kKind), func(func) {} + : AbstractFunction(kGraph), func(func) {} ~GraphFunction() override { if (func) TF_DeleteFunction(func); } - TF_Function* GetTfFunction(TF_Status* s) override { return func; } + Status GetFunctionDef(FunctionDef** fdef) override { + *fdef = &func->fdef; + return Status::OK(); + } - static constexpr AbstractFunctionKind kKind = kGraphFunc; + // For LLVM style RTTI. + static bool classof(const AbstractFunction* ptr) { + return ptr->getKind() == kGraph; + } }; // GraphContext wraps a TF_Graph modeling a single function and manages the // "execution" of operation, i.e. adding them to the function. -class GraphContext : public ExecutionContext { +class GraphContext : public TracingContext { public: explicit GraphContext(const char* name) - : ExecutionContext(kKind), + : TracingContext(kGraph), graph_(new TF_Graph(), TF_DeleteGraph), name_(name) {} - AbstractOp* CreateOperation() override { - // TODO(srbs): Should the lifetime of this op be tied to the context. - return new GraphOp(graph_.get()); + void Release() override { delete this; } + + TracingOperation* CreateOperation() override { + return new GraphOperation(graph_.get()); } - void ExecuteOperation(AbstractOp* op, int num_inputs, - AbstractTensor* const* inputs, OutputList* o, - TF_Status* s) override { - auto* graph_op = dyncast<GraphOp>(op); - if (graph_op == nullptr) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, - "Unable to cast AbstractOp to TF_GraphOp."); - return; + Status AddParameter(DataType dtype, TracingTensorHandle** output) override { + auto operation = CreateOperation(); + TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr)); + TF_RETURN_IF_ERROR( + operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str())); + TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype)); + int num_outputs = 1; + std::vector<AbstractTensorHandle*> outputs(num_outputs); + TF_RETURN_IF_ERROR(operation->Execute( + absl::Span<AbstractTensorHandle*>(outputs), &num_outputs)); + + if (num_outputs != 1) { + return errors::Internal("Expected 1 output but found ", num_outputs); } - auto* tf_opdesc = graph_op->op_.release(); - if (tf_opdesc == nullptr) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, "AbstractOp is incomplete."); - return; - } - for (int i = 0; i < num_inputs; ++i) { - auto* graph_tensor = dyncast<GraphTensor>(inputs[i]); - if (!graph_tensor) { - TF_SetStatus(s, TF_INVALID_ARGUMENT, - "Capturing eager tensors is not supported yet."); - return; - } else { - if (graph_tensor->ctx != this) { - TF_SetStatus( - s, TF_INVALID_ARGUMENT, - "Capturing tensors from other graphs is not supported yet."); - return; - } - TF_AddInput(tf_opdesc, graph_tensor->output); - } - } - auto* operation = TF_FinishOperation(tf_opdesc, s); - // TF_FinishOperation deletes `tf_opdesc` so clear its reference. - graph_op->op_ = nullptr; - if (TF_GetCode(s) != TF_OK) return; - int num_outputs = TF_OperationNumOutputs(operation); - o->outputs.clear(); - o->outputs.reserve(num_outputs); - for (int i = 0; i < num_outputs; ++i) { - o->outputs.push_back(new GraphTensor({operation, i}, this)); + auto* t = dyn_cast<GraphTensor>(outputs[0]); + if (!t) { + return tensorflow::errors::InvalidArgument( + "Unable to cast input to GraphTensor"); } + inputs_.push_back(t->output_); + *output = tensorflow::down_cast<TracingTensorHandle*>(outputs[0]); + return Status::OK(); } - AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override { - TF_OperationDescription* opdesc = - TF_NewOperation(graph_.get(), "Placeholder", - absl::StrCat("_input_", inputs_.size()).c_str()); - TF_SetAttrType(opdesc, "dtype", dtype); - auto* operation = TF_FinishOperation(opdesc, s); - if (!s->status.ok()) return nullptr; - - inputs_.push_back(TF_Output{operation, 0}); - return new GraphTensor(inputs_.back(), this); - } - - AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override { + Status Finalize(OutputList* outputs, AbstractFunction** f) override { std::unique_ptr<GraphFunction> func(new GraphFunction); std::vector<TF_Output> graph_outputs; graph_outputs.reserve(outputs->outputs.size()); - for (AbstractTensor* abstract_output : outputs->outputs) { - GraphTensor* output = dyncast<GraphTensor>(abstract_output); + for (auto* abstract_output : outputs->outputs) { + GraphTensor* output = dyn_cast<GraphTensor>(abstract_output); if (!output) { - TF_SetStatus(s, TF_UNIMPLEMENTED, - "Returning a non-graph tensor from a function has not " - "been implemented yet."); - return nullptr; + return errors::Unimplemented( + "Returning a non-graph tensor from a function has not " + "been implemented yet."); } - graph_outputs.push_back(output->output); + graph_outputs.push_back(output->output_); } + auto s = TF_NewStatus(); func->func = TF_GraphToFunction( graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(), graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s); - if (TF_GetCode(s) != TF_OK) return nullptr; - return func.release(); + TF_RETURN_IF_ERROR(StatusFromTF_Status(s)); + TF_DeleteStatus(s); + *f = func.release(); + return Status::OK(); } - void RegisterFunction(AbstractFunction* func, TF_Status* s) override { - TF_SetStatus(s, TF_UNIMPLEMENTED, - "Registering graph functions has not been implemented yet."); + Status RegisterFunction(AbstractFunction* func) override { + return errors::Unimplemented( + "Registering graph functions has not been implemented yet."); } - ~GraphContext() override {} - - static constexpr ExecutionContextKind kKind = kGraphContext; + Status RemoveFunction(const string& func) override { + return errors::Unimplemented( + "GraphContext::RemoveFunction has not been implemented yet."); + } + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kGraph; + } private: std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_; @@ -220,7 +327,7 @@ class GraphContext : public ExecutionContext { const char* name_; }; -static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) { +static TracingContext* GraphTracingFactory(const char* name, TF_Status* s) { return new GraphContext(name); } @@ -231,5 +338,6 @@ static bool register_tracing = [] { return true; }(); -} // namespace internal +} // namespace graph +} // namespace tracing } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index 8fc696f0f2f..5e09d4a6024 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -19,6 +19,10 @@ limitations under the License. #include <vector> #include "tensorflow/c/c_api.h" +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -26,7 +30,14 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { -namespace internal { + +// Represents the results of the execution of an operation. +struct OutputList { + std::vector<AbstractTensorHandle*> outputs; + int expected_num_outputs = -1; +}; + +namespace tracing { // ============================================================================= // Implementation detail for the unified execution APIs for Eager and tracing @@ -37,165 +48,75 @@ namespace internal { // `c_api_unified_experimental.h` header. // ============================================================================= -// We can't depend on C++ rtti, but we still want to be able to have a safe -// dynamic_cast to provide diagnostics to the user when the API is misused. -// Instead we model RTTI by listing all the possible subclasses for each -// abstract base. Each subclass initializes the base class with the right -// `kind`, which allows an equivalent to `std::dynamic_cast` provided by this -// utility. -template <typename T, typename S> -T* dyncast(S source) { - if (source->getKind() != T::kKind) { - return nullptr; - } - return tensorflow::down_cast<T*>(source); -} - -// Represents either an EagerTensor or a GraphTensor. +// Represents either a MlirTensor or a GraphTensor. // This base class does not expose any public methods other than to distinguish // which subclass it actually is. The user is responsible to use the right -// type of AbstractTensor in their context (do not pass an EagerTensor to a +// type of AbstractTensor in their context (do not pass an MlirTensor to a // GraphContext and vice-versa). -class AbstractTensor { +class TracingTensorHandle : public AbstractTensorHandle { protected: - enum AbstractTensorKind { kMlirTensor, kGraphTensor, kEagerTensor }; - explicit AbstractTensor(AbstractTensorKind kind) : kind_(kind) {} + explicit TracingTensorHandle(AbstractTensorHandleKind kind) + : AbstractTensorHandle(kind) {} public: - // Returns which subclass is this instance of. - AbstractTensorKind getKind() const { return kind_; } - virtual ~AbstractTensor() = default; - - private: - const AbstractTensorKind kind_; -}; - -// Represents the results of the execution of an operation. -struct OutputList { - std::vector<AbstractTensor*> outputs; - int expected_num_outputs = -1; -}; - -// Holds the result of tracing a function. -class AbstractFunction { - protected: - enum AbstractFunctionKind { kGraphFunc }; - explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {} - - public: - // Returns which subclass is this instance of. - AbstractFunctionKind getKind() const { return kind_; } - virtual ~AbstractFunction() = default; - - // Temporary API till we figure the right abstraction for AbstractFunction. - // At the moment both Eager and Graph needs access to a "TF_Function" object. - virtual TF_Function* GetTfFunction(TF_Status* s) = 0; - - private: - const AbstractFunctionKind kind_; + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kGraph || ptr->getKind() == kMlir; + } }; // An abstract operation describes an operation by its type, name, and // attributes. It can be "executed" by the context with some input tensors. // It is allowed to reusing the same abstract operation for multiple execution // on a given context, with the same or different input tensors. -class AbstractOp { +class TracingOperation : public AbstractOperation { protected: - enum AbstractOpKind { kMlirOp, kGraphOp, kEagerOp }; - explicit AbstractOp(AbstractOpKind kind) : kind_(kind) {} + explicit TracingOperation(AbstractOperationKind kind) + : AbstractOperation(kind) {} public: - // Returns which subclass is this instance of. - AbstractOpKind getKind() const { return kind_; } - virtual ~AbstractOp() = default; - - // Sets the type of the operation (for example `AddV2`). - virtual void SetOpType(const char* op_type, TF_Status* s) = 0; - // Sets the name of the operation: this is an optional identifier that is // not intended to carry semantics and preserved/propagated without // guarantees. - virtual void SetOpName(const char* op_name, TF_Status* s) = 0; + virtual Status SetOpName(const char* op_name) = 0; - // Add a `TypeAttribute` on the operation. - virtual void SetAttrType(const char* attr_name, TF_DataType value, - TF_Status* s) = 0; - - private: - const AbstractOpKind kind_; + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kGraph || ptr->getKind() == kMlir; + } }; // This holds the context for the execution: dispatching operations either to an -// eager implementation or to a graph implementation. -struct ExecutionContext { +// MLIR implementation or to a graph implementation. +class TracingContext : public AbstractContext { protected: - enum ExecutionContextKind { kMlirContext, kGraphContext, kEagerContext }; - explicit ExecutionContext(ExecutionContextKind kind) : k(kind) {} + explicit TracingContext(AbstractContextKind kind) : AbstractContext(kind) {} public: - // Returns which subclass is this instance of. - ExecutionContextKind getKind() const { return k; } - virtual ~ExecutionContext() = default; - - // Executes the operation on the provided inputs and populate the OutputList - // with the results. The input tensors must match the current context. - // The effect of "executing" an operation depends on the context: in an Eager - // context it will dispatch it to the runtime for execution, while in a - // tracing context it will add the operation to the current function. - virtual void ExecuteOperation(AbstractOp* op, int num_inputs, - AbstractTensor* const* inputs, OutputList* o, - TF_Status* s) = 0; - - // Creates an empty AbstractOperation suitable to use with this context. - virtual AbstractOp* CreateOperation() = 0; - // Add a function parameter and return the corresponding tensor. - // This is only valid with an ExecutionContext obtained from a TracingContext, - // it'll always error out with an eager context. - virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0; + virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0; // Finalize this context and make a function out of it. The context is in a // invalid state after this call and must be destroyed. - // This is only valid with an ExecutionContext obtained from a TracingContext, - // it'll always error out with an eager context. - virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0; + virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0; - // Registers a functions with this context, after this the function is - // available to be called/referenced by its name in this context. - virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0; - - private: - const ExecutionContextKind k; + // For LLVM style RTTI. + static bool classof(const AbstractContext* ptr) { + return ptr->getKind() == kGraph || ptr->getKind() == kMlir; + } }; -typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*); +typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); void SetDefaultTracingEngine(const char* name); void RegisterTracingEngineFactory(const ::tensorflow::string& name, FactoryFunction factory); +} // namespace tracing -// Create utilities to wrap/unwrap: this convert from the C opaque types to the -// C++ implementation, and back. -#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \ - static inline CPP_CLASS* const& unwrap(C_TYPEDEF* const& o) { \ - return reinterpret_cast<CPP_CLASS* const&>(o); \ - } \ - static inline const CPP_CLASS* const& unwrap(const C_TYPEDEF* const& o) { \ - return reinterpret_cast<const CPP_CLASS* const&>(o); \ - } \ - static inline C_TYPEDEF* const& wrap(CPP_CLASS* const& o) { \ - return reinterpret_cast<C_TYPEDEF* const&>(o); \ - } \ - static inline const C_TYPEDEF* const& wrap(const CPP_CLASS* const& o) { \ - return reinterpret_cast<const C_TYPEDEF* const&>(o); \ - } - -MAKE_WRAP_UNWRAP(TF_ExecutionContext, ExecutionContext) -MAKE_WRAP_UNWRAP(TF_AbstractFunction, AbstractFunction) -MAKE_WRAP_UNWRAP(TF_AbstractTensor, AbstractTensor) -MAKE_WRAP_UNWRAP(TF_AbstractOp, AbstractOp) -MAKE_WRAP_UNWRAP(TF_OutputList, OutputList) - -} // namespace internal +DEFINE_CONVERSION_FUNCTIONS(AbstractContext, TF_ExecutionContext) +DEFINE_CONVERSION_FUNCTIONS(AbstractTensorHandle, TF_AbstractTensor) +DEFINE_CONVERSION_FUNCTIONS(AbstractFunction, TF_AbstractFunction) +DEFINE_CONVERSION_FUNCTIONS(AbstractOperation, TF_AbstractOp) +DEFINE_CONVERSION_FUNCTIONS(OutputList, TF_OutputList) } // namespace tensorflow #endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 24d170f2f99..221ed356645 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -29,15 +30,19 @@ using tensorflow::string; namespace tensorflow { namespace { -class UnifiedCAPI : public ::testing::TestWithParam<const char*> { +class UnifiedCAPI + : public ::testing::TestWithParam<std::tuple<const char*, bool>> { protected: - void SetUp() override { TF_SetTracingImplementation(GetParam()); } + void SetUp() override { + TF_SetTracingImplementation(std::get<0>(GetParam())); + } }; TEST_P(UnifiedCAPI, TestBasicEager) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam())); TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); @@ -45,7 +50,8 @@ TEST_P(UnifiedCAPI, TestBasicEager) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract input tensor. - TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx); + TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f); TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(t, status.get()); @@ -63,7 +69,7 @@ TEST_P(UnifiedCAPI, TestBasicEager) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Execute. - TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get()); + TF_ExecuteOperation(op, 2, inputs, o, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Clean up operation and inputs. @@ -109,9 +115,11 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { // Build inputs and outputs. TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t}; TF_OutputList* add_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Execute. - TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get()); + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Clean up operation and inputs. @@ -123,6 +131,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { // Build eager context. TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam())); TF_ExecutionContext* eager_execution_ctx = TF_NewEagerExecutionContext(opts, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -137,16 +146,14 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { // Build an abstract input tensor. TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(eager_execution_ctx); + TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); TF_AbstractTensor* input_t = TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, eager_execution_ctx, - status.get()); + TF_ExecuteOperation(fn_op, 1, &input_t, add_outputs, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); ASSERT_EQ(1, TF_OutputListNumOutputs(add_outputs)); @@ -195,8 +202,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_AbstractTensor* inputs[2] = {arg0, arg1}; TF_OutputList* add_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Trace the operation now (create a node in the graph). - TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_DeleteAbstractOp(add_op); // Extract the resulting tensor. @@ -215,8 +224,10 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_AbstractTensor* inputs[2] = {arg1, arg1}; TF_OutputList* add_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Trace the operation now (create a node in the graph). - TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, s); + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_DeleteAbstractOp(add_op); // Extract the resulting tensor. @@ -256,6 +267,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { // Build eager context. TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam())); TF_ExecutionContext* eager_execution_ctx = TF_NewEagerExecutionContext(opts, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); @@ -273,7 +285,8 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { std::vector<TF_AbstractTensor*> func_args; { TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(eager_execution_ctx); + TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, 2.0f); func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); @@ -286,7 +299,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { TF_OutputListSetNumOutputs(func_outputs, 2, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs, - eager_execution_ctx, s); + s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_DeleteAbstractOp(fn_op); for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t); @@ -314,20 +327,21 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { TF_DeleteAbstractFunction(func); } -TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { +TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, std::get<1>(GetParam())); TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_DeleteContextOptions(opts); - TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get()); - ASSERT_EQ(nullptr, func); + TF_AbstractFunction* f = TF_FinalizeFunction(ctx, nullptr, status.get()); + ASSERT_EQ(nullptr, f); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); } -TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_AbstractOpSetOpTypeAfterFinishingOpBuildingRaises) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); @@ -348,7 +362,7 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { +TEST_P(UnifiedCAPI, TF_AbstractOpSetOpNameAfterFinishingOpBuildingRaises) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); @@ -369,116 +383,44 @@ TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) { TF_DeleteExecutionContext(graph_ctx); } -TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) { - // Build an Eager context. +TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) { std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( TF_NewStatus(), TF_DeleteStatus); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContextOptions(opts); - - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build an Eager operation. - auto* op = TF_NewAbstractOp(ctx); - TF_AbstractOpSetOpType(op, "Add", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build an abstract input tensor. - TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx); - TFE_TensorHandle* t = TestScalarTensorHandle(eager_ctx, 2.0f); - TF_AbstractTensor* at = - TF_CreateAbstractTensorFromEagerTensor(t, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build inputs and outputs. - TF_AbstractTensor* inputs[2] = {at, at}; - TF_OutputList* o = TF_NewOutputList(); - TF_OutputListSetNumOutputs(o, 1, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build a Graph context. - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Execute eager op using graph context. - TF_ExecuteOperation(op, 2, inputs, o, graph_ctx, status.get()); - ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); - - // Clean up operation and inputs. - TF_DeleteAbstractOp(op); - TF_DeleteAbstractTensor(at); - - TF_DeleteOutputList(o); - TF_DeleteExecutionContext(ctx); - TF_DeleteExecutionContext(graph_ctx); -} - -TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) { - std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( - TF_NewStatus(), TF_DeleteStatus); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Add a placeholder to the graph. - auto* placeholder_op = TF_NewAbstractOp(graph_ctx); - TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build inputs and outputs. - TF_OutputList* placeholder_outputs = TF_NewOutputList(); - - // Execute. - TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs, - graph_ctx, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs)); - TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0); - - // Delete placeholder op. - TF_DeleteAbstractOp(placeholder_op); - - // Build an abstract operation. - auto* add_op = TF_NewAbstractOp(graph_ctx); - TF_AbstractOpSetOpType(add_op, "Add", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TF_AbstractOpSetOpName(add_op, "my_add", status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - // Build inputs and outputs. - TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t}; - TF_OutputList* add_outputs = TF_NewOutputList(); - - // Build eager context. - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TF_ExecutionContext* eager_execution_ctx = - TF_NewEagerExecutionContext(opts, status.get()); - ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_DeleteContextOptions(opts); - - // Execute. - TF_ExecuteOperation(add_op, 2, inputs, add_outputs, eager_execution_ctx, - status.get()); + auto placeholder_t = + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AbstractTensorGetEagerTensor(placeholder_t, status.get()); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); - // Clean up operation and inputs. TF_DeleteAbstractTensor(placeholder_t); - TF_DeleteAbstractOp(add_op); - TF_DeleteOutputList(add_outputs); - TF_DeleteOutputList(placeholder_outputs); TF_DeleteExecutionContext(graph_ctx); - TF_DeleteExecutionContext(eager_execution_ctx); } +TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) { + std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_ExecutionContextGetTFEContext(graph_ctx, status.get()); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); + + TF_DeleteExecutionContext(graph_ctx); +} +#ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, - ::testing::Values("graphdef", "mlir")); + ::testing::Combine(::testing::Values("graphdef", + "mlir"), + ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, + ::testing::Combine(::testing::Values("graphdef", + "mlir"), + ::testing::Values(false))); +#endif } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index 6d8b73b758a..801e35280d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -23,8 +23,12 @@ tf_cuda_library( copts = tf_copts() + tfe_xla_copts(), deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:tensor_interface", "//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:c_api_unified_internal", @@ -35,6 +39,7 @@ tf_cuda_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 0e8b7fedd9b..935e87c5fa4 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -26,14 +26,19 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/tensor_interface.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_status_internal.h" @@ -47,16 +52,21 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" namespace mlir { namespace TF { -using tensorflow::internal::AbstractFunction; -using tensorflow::internal::AbstractOp; -using tensorflow::internal::AbstractTensor; -using tensorflow::internal::dyncast; -using tensorflow::internal::ExecutionContext; -using tensorflow::internal::OutputList; +using tensorflow::AbstractFunction; +using tensorflow::AbstractOperation; +using tensorflow::AbstractTensorHandle; +using tensorflow::AbstractTensorInterface; +using tensorflow::dyn_cast; +using tensorflow::OutputList; +using tensorflow::string; +using tensorflow::tracing::TracingContext; +using tensorflow::tracing::TracingOperation; +using tensorflow::tracing::TracingTensorHandle; namespace { @@ -78,43 +88,104 @@ Status ConvertDataTypeToTensor(tensorflow::DataType dtype, Builder builder, return s; } -class MlirTensor : public AbstractTensor { +class MlirTensor : public TracingTensorHandle { public: - explicit MlirTensor(Value value) : AbstractTensor(kKind), value_(value) {} + explicit MlirTensor(Value value) + : TracingTensorHandle(kMlir), value_(value) {} + + void Release() override { delete this; } Value getValue() { return value_; } - static constexpr AbstractTensorKind kKind = kMlirTensor; + // For LLVM style RTTI. + static bool classof(const AbstractTensorHandle* ptr) { + return ptr->getKind() == kMlir; + } private: Value value_; }; -class MlirAbstractOp : public AbstractOp { +class MlirFunctionContext; + +class MlirAbstractOp : public TracingOperation { public: - explicit MlirAbstractOp(MLIRContext* context) - : AbstractOp(kKind), context_(context) {} + explicit MlirAbstractOp(MLIRContext* context, + MlirFunctionContext* function_context) + : TracingOperation(kMlir), + context_(context), + function_context_(function_context) {} - void SetOpType(const char* op_type, TF_Status* s) override; + void Release() override { delete this; } - void SetAttrType(const char* attr_name, TF_DataType dtype, - TF_Status* s) override; + Status Reset(const char* op, const char* raw_device_name) override; - void SetOpName(const char* const op_name, TF_Status* s) override; + const string& Name() const override; + + const string& DeviceName() const override; + + Status SetDeviceName(const char* name) override; + + Status AddInput(AbstractTensorHandle* input) override; + Status AddInputList(absl::Span<AbstractTensorHandle*> inputs) override; + Status Execute(absl::Span<AbstractTensorHandle*> retvals, + int* num_retvals) override; + + Status SetAttrString(const char* attr_name, const char* data, + size_t length) override; + Status SetAttrInt(const char* attr_name, int64_t value) override; + Status SetAttrFloat(const char* attr_name, float value) override; + Status SetAttrBool(const char* attr_name, bool value) override; + Status SetAttrType(const char* attr_name, + tensorflow::DataType dtype) override; + Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override; + Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override; + Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + Status SetAttrStringList(const char* attr_name, const void* const* values, + const size_t* lengths, int num_values) override; + Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + Status SetAttrTypeList(const char* attr_name, + const tensorflow::DataType* values, + int num_values) override; + Status SetAttrBoolList(const char* attr_name, const unsigned char* values, + int num_values) override; + Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + Status SetAttrFunctionList( + const char* attr_name, + absl::Span<const AbstractOperation*> values) override; + + Status SetOpName(const char* const op_name) override; MLIRContext* GetContext() { return context_; } - Type AddRef(Type type, TF_Status* s); + Status AddRef(Type type, Type* output_type); - OperationState* Create(ArrayRef<Value> operands, TF_Status* s); + Status Create(ArrayRef<Value> operands, OperationState**); - static constexpr AbstractOpKind kKind = kMlirOp; + // For LLVM style RTTI. + static bool classof(const AbstractOperation* ptr) { + return ptr->getKind() == kMlir; + } private: MLIRContext* context_; + MlirFunctionContext* function_context_; + SmallVector<Value, 8> operands_; llvm::StringMap<Attribute> attrs_; std::unique_ptr<OperationState> state_; const char* op_name_ = nullptr; + string tf_op_type_; + // TODO(srbs): Use this. + string device_name_; }; // MlirFunction is a thin wrapper over a FuncOp. @@ -122,14 +193,17 @@ class MlirFunction : public AbstractFunction { public: explicit MlirFunction(std::unique_ptr<MLIRContext> context, OwningModuleRef module, FuncOp func) - : AbstractFunction(kKind), + : AbstractFunction(kMlir), context_(std::move(context)), module_(std::move(module)), func_(func) {} - TF_Function* GetTfFunction(TF_Status* s) override; + Status GetFunctionDef(tensorflow::FunctionDef** f) override; - static constexpr AbstractFunctionKind kKind = kGraphFunc; + // For LLVM style RTTI. + static bool classof(const AbstractFunction* ptr) { + return ptr->getKind() == kMlir; + } private: std::unique_ptr<MLIRContext> context_; @@ -137,10 +211,10 @@ class MlirFunction : public AbstractFunction { FuncOp func_; }; -class MlirFunctionContext : public ExecutionContext { +class MlirFunctionContext : public TracingContext { public: explicit MlirFunctionContext(const char* name) - : ExecutionContext(kKind), + : TracingContext(kMlir), context_(std::make_unique<MLIRContext>()), builder_(context_.get()) { // TODO(aminim) figure out the location story here @@ -151,24 +225,27 @@ class MlirFunctionContext : public ExecutionContext { builder_ = OpBuilder::atBlockBegin(func_.addEntryBlock()); } - AbstractOp* CreateOperation() override { - return new MlirAbstractOp(context_.get()); + void Release() override { delete this; } + + AbstractOperation* CreateOperation() override { + return new MlirAbstractOp(context_.get(), this); } + Status AddParameter(tensorflow::DataType dtype, + TracingTensorHandle** handle) override; - void ExecuteOperation(AbstractOp* abstract_op, int num_inputs, - AbstractTensor* const* inputs, OutputList* o, - TF_Status* s) override; + Status Finalize(OutputList* outputs, AbstractFunction** f) override; - AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override; - - AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override; - - void RegisterFunction(AbstractFunction* func, TF_Status* s) override { - s->status = tensorflow::errors::Unimplemented( + Status RegisterFunction(AbstractFunction* func) override { + return tensorflow::errors::Unimplemented( "Registering graph functions has not been implemented yet."); } - static constexpr ExecutionContextKind kKind = kMlirContext; + Status RemoveFunction(const string& func) override { + return tensorflow::errors::Unimplemented( + "MlirFunctionContext::RemoveFunction has not been implemented yet."); + } + + Operation* CreateOperationFromState(const OperationState& state); private: std::unique_ptr<MLIRContext> context_; @@ -177,91 +254,88 @@ class MlirFunctionContext : public ExecutionContext { OwningModuleRef module_; }; -void MlirAbstractOp::SetOpType(const char* op_type, TF_Status* s) { +Status MlirAbstractOp::Reset(const char* op, const char* device_name) { if (state_) { - s->status = tensorflow::errors::FailedPrecondition( - "SetOpType called on already built op."); - return; + return tensorflow::errors::FailedPrecondition( + "Reset called on already built op."); } + tf_op_type_ = op; std::string name = "tf."; - name += op_type; + name += op; // TODO(aminim) figure out the location story here state_ = std::make_unique<OperationState>(UnknownLoc::get(context_), name); + return Status::OK(); } -void MlirAbstractOp::SetAttrType(const char* attr_name, TF_DataType dtype, - TF_Status* s) { +Status MlirAbstractOp::SetAttrType(const char* attr_name, + tensorflow::DataType dtype) { if (!state_) { - s->status = tensorflow::errors::FailedPrecondition( - "op_type must be specified before specifying attrs."); - return; + return Status(tensorflow::error::Code::FAILED_PRECONDITION, + "op_type must be specified before specifying attrs."); } Type mlir_type; Builder builder(context_); - s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype), - builder, &mlir_type); - if (!s->status.ok()) return; + TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder, &mlir_type)); attrs_[attr_name] = TypeAttr::get(mlir_type); + return Status::OK(); } -void MlirAbstractOp::SetOpName(const char* const op_name, TF_Status* s) { +Status MlirAbstractOp::SetOpName(const char* const op_name) { // TODO(aminim): should we use a location? if (op_name_) { - s->status = tensorflow::errors::FailedPrecondition( + return tensorflow::errors::FailedPrecondition( "SetOpName called on already built op."); - return; } op_name_ = op_name; + return Status::OK(); } -Type MlirAbstractOp::AddRef(Type type, TF_Status* s) { +Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Type elt_type = getElementTypeOrSelf(type); if (elt_type.isa<mlir::TF::TensorFlowRefType>()) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Requested reference to a reference type"); - return nullptr; } elt_type = TensorFlowRefType::get(elt_type); if (RankedTensorType tensor_type = type.dyn_cast<RankedTensorType>()) { - return RankedTensorType::get(tensor_type.getShape(), elt_type); + *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type); } - return UnrankedTensorType::get(elt_type); + *output_type = UnrankedTensorType::get(elt_type); + return Status::OK(); } -OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) { +Status MlirAbstractOp::Create(ArrayRef<Value> operands, + OperationState** state) { state_->operands = llvm::to_vector<4>(operands); const tensorflow::OpDef* op_def; auto node_name = state_->name.getStringRef().drop_front( TensorFlowDialect::getDialectNamespace().size() + 1); - s->status = - tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def); - if (!s->status.ok()) return nullptr; + TF_RETURN_IF_ERROR( + tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def)); Builder builder(context_); // Process operands according to the op_def and infer derived attributes. int current_operand = 0; for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) { if (!input_arg.number_attr().empty()) { // TODO(b/156122856): we don't support variadic operands. - s->status = tensorflow::errors::Unimplemented( + return tensorflow::errors::Unimplemented( "Unsupported 'number_attr' for '", input_arg.number_attr(), "'"); - return nullptr; } else if (!input_arg.type_list_attr().empty()) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'"); - return nullptr; } if (current_operand >= operands.size()) { - s->status = tensorflow::errors::InvalidArgument("Missing operand for '", - input_arg.name(), "'"); - return nullptr; + return tensorflow::errors::InvalidArgument("Missing operand for '", + input_arg.name(), "'"); } Type expected_type; if (input_arg.type() != tensorflow::DT_INVALID) { - s->status = - ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type); - if (!s->status.ok()) return nullptr; - if (input_arg.is_ref()) expected_type = AddRef(expected_type, s); - if (!s->status.ok()) return nullptr; + TF_RETURN_IF_ERROR( + ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type)); + Type output_type; + if (input_arg.is_ref()) + TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type)); + expected_type = output_type; } else { expected_type = operands[current_operand].getType(); } @@ -277,17 +351,15 @@ OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) { // Same type repeated "repeats" times. Attribute repeats_attr = attrs_[output_arg.number_attr()]; if (!repeats_attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Missing attribute '", output_arg.number_attr(), "' required for output list '", output_arg.name(), "'"); - return nullptr; } if (!repeats_attr.isa<IntegerAttr>()) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Attribute '", output_arg.number_attr(), "' required for output list '", output_arg.name(), "' isn't an integer"); - return nullptr; } int64_t repeats = repeats_attr.cast<IntegerAttr>().getInt(); @@ -295,102 +367,186 @@ OperationState* MlirAbstractOp::Create(ArrayRef<Value> operands, TF_Status* s) { // Same type repeated "repeats" times. Attribute attr = attrs_[output_arg.type_attr()]; if (!attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Missing attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "'"); - return nullptr; } TypeAttr type_attr = attr.dyn_cast<TypeAttr>(); if (!type_attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "' isn't a type attribute"); - return nullptr; } for (int i = 0; i < repeats; ++i) state_->types.push_back(type_attr.getType()); } else if (output_arg.type() != tensorflow::DT_INVALID) { for (int i = 0; i < repeats; ++i) { Type type; - s->status = - ConvertDataTypeToTensor(output_arg.type(), builder, &type); - if (!s->status.ok()) return nullptr; + TF_RETURN_IF_ERROR( + ConvertDataTypeToTensor(output_arg.type(), builder, &type)); state_->types.push_back(type); } } else { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Missing type or type_attr field in ", output_arg.ShortDebugString()); - return nullptr; } } else if (!output_arg.type_attr().empty()) { Attribute attr = attrs_[output_arg.type_attr()]; if (!attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Missing attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "'"); - return nullptr; } TypeAttr type_attr = attr.dyn_cast<TypeAttr>(); if (!type_attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "' isn't a type attribute"); - return nullptr; } state_->types.push_back(type_attr.getValue()); } else if (!output_arg.type_list_attr().empty()) { // This is pointing to an attribute which is an array of types. Attribute attr = attrs_[output_arg.type_list_attr()]; if (!attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Missing attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "'"); - return nullptr; } ArrayAttr array_attr = attr.dyn_cast<ArrayAttr>(); if (!array_attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "' isn't an array attribute"); - return nullptr; } for (Attribute attr : array_attr) { TypeAttr type_attr = attr.dyn_cast<TypeAttr>(); if (!type_attr) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Array Attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "' has a non-Type element"); - return nullptr; } state_->types.push_back(type_attr.getValue()); } } else if (output_arg.type() != tensorflow::DT_INVALID) { Type type; Builder builder(context_); - s->status = ConvertDataTypeToTensor(output_arg.type(), builder, &type); - if (!s->status.ok()) return nullptr; + TF_RETURN_IF_ERROR( + ConvertDataTypeToTensor(output_arg.type(), builder, &type)); state_->types.push_back(type); } else { - s->status = tensorflow::errors::InvalidArgument( - "No type fields in ", output_arg.ShortDebugString()); - if (!s->status.ok()) return nullptr; + return tensorflow::errors::InvalidArgument("No type fields in ", + output_arg.ShortDebugString()); } if (output_arg.is_ref()) { // For all types that were added by this function call, make them refs. for (Type& type : llvm::make_range(&state_->types[original_size], state_->types.end())) { - type = AddRef(type, s); - if (!s->status.ok()) return nullptr; + Type output_type; + TF_RETURN_IF_ERROR(AddRef(type, &output_type)); + type = output_type; } } } - return state_.get(); + *state = state_.get(); + return Status::OK(); } -TF_Function* MlirFunction::GetTfFunction(TF_Status* s) { +const string& MlirAbstractOp::Name() const { return tf_op_type_; } + +const string& MlirAbstractOp::DeviceName() const { return device_name_; } + +Status MlirAbstractOp::SetDeviceName(const char* name) { + device_name_ = name; + return Status::OK(); +} + +Status MlirAbstractOp::AddInputList(absl::Span<AbstractTensorHandle*> inputs) { + return tensorflow::errors::Unimplemented( + "AddInputList has not been implemented yet."); +} + +Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data, + size_t length) { + return tensorflow::errors::Unimplemented( + "SetAttrString has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) { + return tensorflow::errors::Unimplemented( + "SetAttrInt has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) { + return tensorflow::errors::Unimplemented( + "SetAttrFloat has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) { + return tensorflow::errors::Unimplemented( + "SetAttrBool has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) { + return tensorflow::errors::Unimplemented( + "SetAttrShape has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { + return tensorflow::errors::Unimplemented( + "SetAttrFunction has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name, + const char* value, size_t length) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionName has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { + return tensorflow::errors::Unimplemented( + "SetAttrTensor has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) { + return tensorflow::errors::Unimplemented( + "SetAttrStringList has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrFloatList(const char* attr_name, + const float* values, int num_values) { + return tensorflow::errors::Unimplemented( + "SetAttrFloatList has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrIntList(const char* attr_name, + const int64_t* values, int num_values) { + return tensorflow::errors::Unimplemented( + "SetAttrIntList has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrTypeList(const char* attr_name, + const tensorflow::DataType* values, + int num_values) { + return tensorflow::errors::Unimplemented( + "SetAttrTypeList has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { + return tensorflow::errors::Unimplemented( + "SetAttrBoolList has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, int num_values) { + return tensorflow::errors::Unimplemented( + "SetAttrShapeList has not been implemented yet."); +} +Status MlirAbstractOp::SetAttrFunctionList( + const char* attr_name, absl::Span<const AbstractOperation*> values) { + return tensorflow::errors::Unimplemented( + "SetAttrFunctionList has not been implemented yet."); +} + +Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { PassManager pm(func_.getContext()); pm.addNestedPass<FuncOp>(CreateFunctionalToExecutorDialectConversionPass()); pm.addNestedPass<FuncOp>(CreateBreakUpIslandsPass()); @@ -400,75 +556,59 @@ TF_Function* MlirFunction::GetTfFunction(TF_Status* s) { StatusScopedDiagnosticHandler diag_handler(func_.getContext()); LogicalResult result = pm.run(func_.getParentOfType<ModuleOp>()); (void)result; - s->status = diag_handler.ConsumeStatus(); - if (!s->status.ok()) return nullptr; + TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus()); tensorflow::GraphExportConfig configs; - std::unique_ptr<TF_Function> tf_function(new TF_Function); - s->status = ConvertMlirFunctionToFunctionLibraryDef(func_, configs, - &tf_function->fdef); - return tf_function.release(); + *f = new tensorflow::FunctionDef(); + return ConvertMlirFunctionToFunctionLibraryDef(func_, configs, *f); } -void MlirFunctionContext::ExecuteOperation(AbstractOp* abstract_op, - int num_inputs, - AbstractTensor* const* inputs, - OutputList* o, TF_Status* s) { - auto* mlir_op = dyncast<MlirAbstractOp>(abstract_op); - if (mlir_op == nullptr) { - s->status = tensorflow::errors::InvalidArgument( - "Unable to cast AbstractOp to TF_GraphOp."); - return; - } - SmallVector<Value, 8> operands; - for (int i = 0; i < num_inputs; ++i) { - auto* operand = dyncast<MlirTensor>(inputs[i]); - if (!operand) { - s->status = tensorflow::errors::InvalidArgument( - "Capturing eager tensors is not supported yet."); - return; - } - if (operand->getValue().getContext() != context_.get()) { - s->status = tensorflow::errors::InvalidArgument( - "Capturing tensors from other context is not supported."); - return; - } - operands.push_back(operand->getValue()); - } - OperationState* state = mlir_op->Create(operands, s); - if (!s->status.ok() || !state) return; - Operation* op = builder_.createOperation(*state); - int num_results = op->getNumResults(); - o->outputs.clear(); - o->outputs.reserve(num_results); - for (Value result : op->getResults()) - o->outputs.push_back(new MlirTensor(result)); +Status MlirAbstractOp::Execute(absl::Span<AbstractTensorHandle*> retvals, + int* num_retvals) { + OperationState* state; + TF_RETURN_IF_ERROR(Create(operands_, &state)); + Operation* op = function_context_->CreateOperationFromState(*state); + *num_retvals = op->getNumResults(); + for (int i = 0; i < *num_retvals; i++) + retvals[i] = new MlirTensor(op->getResult(i)); + return Status::OK(); } -AbstractTensor* MlirFunctionContext::AddParameter(TF_DataType dtype, - TF_Status* s) { +Operation* MlirFunctionContext::CreateOperationFromState( + const OperationState& state) { + return builder_.createOperation(state); +} + +Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype, + TracingTensorHandle** handle) { Type type; - s->status = ConvertDataTypeToTensor(static_cast<tensorflow::DataType>(dtype), - builder_, &type); - if (!s->status.ok()) return nullptr; - return new MlirTensor(func_.getBody().front().addArgument(type)); + TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type)); + *handle = new MlirTensor(func_.getBody().front().addArgument(type)); + return Status::OK(); } -AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs, - TF_Status* s) { +Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) { + auto* operand = dyn_cast<MlirTensor>(input); + if (!operand) { + return tensorflow::errors::InvalidArgument( + "Unable to cast input to MlirTensor"); + } + operands_.push_back(operand->getValue()); + return Status::OK(); +} +Status MlirFunctionContext::Finalize(OutputList* outputs, + AbstractFunction** f) { Block& body = func_.getBody().front(); SmallVector<Value, 8> ret_operands; - for (AbstractTensor* output : outputs->outputs) { - auto* operand = dyncast<MlirTensor>(output); + for (auto* output : outputs->outputs) { + auto* operand = dyn_cast<MlirTensor>(output); if (!operand) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Capturing eager tensors is not supported yet."); - return nullptr; } if (operand->getValue().getContext() != context_.get()) { - s->status = tensorflow::errors::InvalidArgument( + return tensorflow::errors::InvalidArgument( "Capturing tensors from other context is not supported."); - return nullptr; } ret_operands.push_back(operand->getValue()); } @@ -478,16 +618,17 @@ AbstractFunction* MlirFunctionContext::Finalize(OutputList* outputs, auto result_types = llvm::to_vector<8>(body.getTerminator()->getOperandTypes()); func_.setType(FunctionType::get(arg_types, result_types, func_.getContext())); - return new MlirFunction(std::move(context_), std::move(module_), func_); + *f = new MlirFunction(std::move(context_), std::move(module_), func_); + return Status::OK(); } extern "C" { -ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { +TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s) { RegisterDialects(); return new MlirFunctionContext(fn_name); } } -} // end anonymous namespace -} // end namespace TF -} // end namespace mlir +} // namespace +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc index 778f4b777a3..01a079b5247 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir_registration.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -using tensorflow::internal::ExecutionContext; +using tensorflow::tracing::TracingContext; extern "C" { -ExecutionContext* MlirTracingFactory(const char* fn_name, TF_Status* s); +TracingContext* MlirTracingFactory(const char* fn_name, TF_Status* s); } namespace {