diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5a39c17e1d9..865f8061ae0 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1397,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, tensorflow::errors::InvalidArgument("Invalid FunctionDef proto"); return; } - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->AddFunctionDef(function_def); + status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def); } void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->AddFunctionDef(function->fdef); + status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef); } void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name, TF_Status* status) { - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->RemoveFunction(name); + status->status = tensorflow::unwrap(ctx)->RemoveFunction(name); } unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index 2861fa43b66..9cd0669ff68 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -104,6 +104,14 @@ class AbstractContextInterface { // Block until all pending nodes are finished. virtual Status AsyncWait() = 0; + // Add a function (serialized FunctionDef protocol buffer) so that it can + // be executed as an op. Return error if the function with the same name + // already exists. + virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; + // Remove a function. 'func' argument is the name of a previously added + // FunctionDef. The name is in fdef.signature.name. + virtual Status RemoveFunction(const string& func) = 0; + protected: virtual ~AbstractContextInterface() {} }; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3aa700d3718..89188352677 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -3572,6 +3572,24 @@ StatusOr ConvertGraphToMlir( /*func_name=*/"main"); } +stream_executor::port::StatusOr ConvertFunctionToMlir( + mlir::StringRef name, const FunctionLibraryDefinition& flib_def, + mlir::MLIRContext* context) { + const tensorflow::FunctionDef* fdef = flib_def.Find(name.str()); + if (fdef == nullptr) + return tensorflow::errors::NotFound("Cannot find function ", name.str()); + + std::unique_ptr fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), + &flib_def, &fbody)); + + tensorflow::GraphDebugInfo dummy_debug_info; + tensorflow::GraphImportConfig specs; + specs.graph_as_function = true; + return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, + flib_def, specs, name); +} + StatusOr ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, bool add_default_attributes) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index bdb72345201..80001c44389 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -45,6 +46,13 @@ stream_executor::port::StatusOr ConvertGraphToMlir( const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, mlir::MLIRContext* context); +// [Experimental] +// Given a Function, returns a MLIR module containing the graph, expressed with +// tf_executor dialect. +stream_executor::port::StatusOr ConvertFunctionToMlir( + mlir::StringRef name, const FunctionLibraryDefinition& flib_def, + mlir::MLIRContext* context); + // Given a SavedModel, returns a MLIR module containing the functions, expressed // with tf_executor dialect. stream_executor::port::StatusOr ConvertSavedModelToMlir( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index cceb883a965..e902ede7f15 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -275,7 +275,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { // Add the given `fdef` to the local FunctionLibraryDefinition. And add an // entry to the KernelAndDevice cache for it if it's not exist. - Status AddFunctionDef(const FunctionDef& fdef); + Status AddFunctionDef(const FunctionDef& fdef) override; // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add // it to the local FunctionLibraryDefinition as well, but no need to add it // to the KernelAndDevice cache since they won't be executed as @@ -286,7 +286,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { const FunctionDef* GetFunctionDef(const string& function_name); - Status RemoveFunction(const string& func); + Status RemoveFunction(const string& func) override; // Wait for pending nodes to be finished in local executors (including context // default executor and thread executors) and executors on remote workers. diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index a44d8a493c1..014cb97f72f 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -803,6 +803,7 @@ cuda_py_test( name = "def_function_test", srcs = ["def_function_test.py"], python_version = "PY3", + tfrt_enabled = True, deps = [ ":def_function", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index 350dddb391e..315a4cfd056 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -707,7 +707,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase): self._benchmark_tfe_py_execute_matmul( m, transpose_b=True, num_iters=self._num_iters_100_by_784) - @test_util.disable_tfrt("Graph is not supported yet. b/156187905") def benchmark_defun_matmul_100_by_784_CPU(self): with context.device(CPU): m = self._m_100_by_784.cpu() diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index d5daa3acc99..0549da2c256 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -68,6 +68,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0) + @test_util.disable_tfrt('Variable argument is not supported') def testFailIfVariablesAreCreatedMoreThanOnce(self): @def_function.function @@ -77,6 +78,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): with self.assertRaises(ValueError): fn(1.0) + @test_util.disable_tfrt('Variable argument is not supported') def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self): state = [] @@ -96,6 +98,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(f(range(5)), 1.0) + @test_util.disable_tfrt('Variable argument is not supported') def testCorrectVariableCreation(self): state = [] @@ -109,6 +112,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0) + @test_util.disable_tfrt('Variable argument is not supported') def testFunctionInitializer(self): state = [] @@ -121,6 +125,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0) + @test_util.disable_tfrt('Variable argument is not supported') def testFunctionMultipleVariableInitializer(self): state = [] @@ -134,6 +139,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0]) + @test_util.disable_tfrt('Variable argument is not supported') def testFunctionInitializationFunction(self): state = [] @@ -151,6 +157,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): init_fn() self.assertEqual(state[0].numpy(), 2.0) + @test_util.disable_tfrt('Variable argument is not supported') def testVariableInitializerNotConstant(self): state = [] @@ -180,6 +187,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(sess.run(state[0]), 2.0) self.assertAllEqual(self.evaluate(result), 6.0) + @test_util.disable_tfrt('Variable argument is not supported') def testLegacyGraphModeVariablesNonTrivialInitializer(self): with ops.Graph().as_default(), self.test_session() as sess: state = [] @@ -199,6 +207,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(sess.run(state[0]), 6.0) self.assertAllEqual(self.evaluate(result), 18.0) + @test_util.disable_tfrt('Variable argument is not supported') def testLegacyGraphModeInputDependentInitializerFails(self): with ops.Graph().as_default(): state = [] @@ -213,6 +222,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): lift_to_graph.UnliftableError, r'transitively.* mul .* x'): fn(constant_op.constant(3.0)) + @test_util.disable_tfrt('Variable argument is not supported') def testMethod(self): class MyModel(object): @@ -241,6 +251,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): def_function.function(functools.partial(lambda x, y: x + y, 1.))( constant_op.constant(2.))) + @test_util.disable_tfrt('Partial is not supported') def test_functools_partial_new_default(self): def f(x=3, y=7): return x + y @@ -249,6 +260,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertEqual(func().numpy(), 9) self.assertEqual(func(y=8).numpy(), 11) + @test_util.disable_tfrt('Partial is not supported') def test_functools_partial_keywords(self): def f(x, y): return x + y @@ -257,6 +269,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1]))) self.assertAllEqual(func(), [0.0]) + @test_util.disable_tfrt('Partial is not supported') def test_functools_partial_single_positional(self): def f(x, y): return x + y @@ -265,6 +278,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): functools.partial(f, constant_op.constant(1))) self.assertAllEqual(func(5), 6) + @test_util.disable_tfrt('Partial is not supported') def test_complicated_partial_with_defaults(self): def identity(*args): @@ -312,6 +326,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): (tensor_spec.TensorSpec( None, dtypes.float32, name='x'),)) + @test_util.disable_tfrt('Variable argument is not supported') @test_util.run_in_graph_and_eager_modes def test_variable_naming(self): class HasVars(module.Module): @@ -382,6 +397,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): 'defined in another function or code block'): f(array_ops.zeros(shape=(8, 42, 3))) + @test_util.disable_tfrt('Control flow is not supported') def testRuntimeErrorNotSticky(self): @def_function.function @@ -486,6 +502,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): constant_op.constant(3.), constant_op.constant(4.))) + @test_util.disable_tfrt('Variable argument is not supported') def testVariableCreatorScope(self): created_variables = [] captured_variables = [] @@ -505,6 +522,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): f() self.assertEqual(created_variables, captured_variables) + @test_util.disable_tfrt('Variable argument is not supported') def testVarAlreadyInitializedNoClobbering(self): v_holder = [] @@ -522,6 +540,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): add_var.get_concrete_function(constant_op.constant(2.)) self.assertAllClose([13., 14.], add_var(constant_op.constant(2.))) + @test_util.disable_tfrt('Variable argument is not supported') def testSameVariableTwice(self): v = variables.Variable(1.0) @@ -531,6 +550,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(add(v, v), 2.0) + @test_util.disable_tfrt('Variable argument is not supported') def testVariableUpdate(self): v1 = variables.Variable(1.0) v2 = variables.Variable(2.0) @@ -566,6 +586,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertIs(func_a, func_b) + @test_util.disable_tfrt('Nested function is not supported') def testInitializationInNestedCall(self): v_holder = [] @@ -588,6 +609,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): v_holder[1].assign(11.) self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.))) + @test_util.disable_tfrt('Variable argument is not supported') @test_util.run_gpu_only def testDeviceAnnotationRespected(self): a = [] @@ -607,6 +629,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): create_variable() self.assertRegexpMatches(a[0].device, 'CPU') + @test_util.disable_tfrt('Variable argument is not supported') @test_util.run_gpu_only def testDeviceAnnotationForInitializerRespected(self): a = [] @@ -681,6 +704,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(cloned(x)), self.evaluate(cloned_py_function(x))) + @test_util.disable_tfrt('Variable argument is not supported') def testLiftPlaceholderInitializedVariable(self): with ops.Graph().as_default(): var_list = [] @@ -834,6 +858,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase): self.assertLen(logs.output, 1) self.assertIn('Tracing is expensive', logs.output[0]) + @test_util.disable_tfrt('Nested function is not supported') def test_frequent_retracing_warning_nested(self): if sys.version_info[0] < 3: self.skipTest('self.assertLogs() call is not available in Python 2.')