Initial @tf.function support in TFRT. And enable some tests in def_function_test.py.

Some limitations of current implementation:
1. We doesn't have a cache for compiled function. Right now it JIT compile the function in each invocation.
2. Does not support nested function.
3. Does not support variable.

PiperOrigin-RevId: 315031999
Change-Id: I8d96ed26d0da7c071b7f89e65d6ada7bbc290a37
This commit is contained in:
Xiao Yu 2020-06-05 18:32:35 -07:00 committed by TensorFlower Gardener
parent e59ee30dd2
commit 54d74710e4
8 changed files with 65 additions and 12 deletions

View File

@ -1397,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
return;
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function_def);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
}
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->AddFunctionDef(function->fdef);
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
}
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status = context->RemoveFunction(name);
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
}
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {

View File

@ -104,6 +104,14 @@ class AbstractContextInterface {
// Block until all pending nodes are finished.
virtual Status AsyncWait() = 0;
// Add a function (serialized FunctionDef protocol buffer) so that it can
// be executed as an op. Return error if the function with the same name
// already exists.
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
// Remove a function. 'func' argument is the name of a previously added
// FunctionDef. The name is in fdef.signature.name.
virtual Status RemoveFunction(const string& func) = 0;
protected:
virtual ~AbstractContextInterface() {}
};

View File

@ -3572,6 +3572,24 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
/*func_name=*/"main");
}
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
mlir::MLIRContext* context) {
const tensorflow::FunctionDef* fdef = flib_def.Find(name.str());
if (fdef == nullptr)
return tensorflow::errors::NotFound("Cannot find function ", name.str());
std::unique_ptr<tensorflow::FunctionBody> fbody;
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
&flib_def, &fbody));
tensorflow::GraphDebugInfo dummy_debug_info;
tensorflow::GraphImportConfig specs;
specs.graph_as_function = true;
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
flib_def, specs, name);
}
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
absl::Span<std::string> exported_names, bool add_default_attributes) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
@ -45,6 +46,13 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
mlir::MLIRContext* context);
// [Experimental]
// Given a Function, returns a MLIR module containing the graph, expressed with
// tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
mlir::MLIRContext* context);
// Given a SavedModel, returns a MLIR module containing the functions, expressed
// with tf_executor dialect.
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(

View File

@ -275,7 +275,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
// entry to the KernelAndDevice cache for it if it's not exist.
Status AddFunctionDef(const FunctionDef& fdef);
Status AddFunctionDef(const FunctionDef& fdef) override;
// `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
// it to the local FunctionLibraryDefinition as well, but no need to add it
// to the KernelAndDevice cache since they won't be executed as
@ -286,7 +286,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
const FunctionDef* GetFunctionDef(const string& function_name);
Status RemoveFunction(const string& func);
Status RemoveFunction(const string& func) override;
// Wait for pending nodes to be finished in local executors (including context
// default executor and thread executors) and executors on remote workers.

View File

@ -803,6 +803,7 @@ cuda_py_test(
name = "def_function_test",
srcs = ["def_function_test.py"],
python_version = "PY3",
tfrt_enabled = True,
deps = [
":def_function",
"//tensorflow/python:client_testlib",

View File

@ -707,7 +707,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
self._benchmark_tfe_py_execute_matmul(
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
def benchmark_defun_matmul_100_by_784_CPU(self):
with context.device(CPU):
m = self._m_100_by_784.cpu()

View File

@ -68,6 +68,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testFailIfVariablesAreCreatedMoreThanOnce(self):
@def_function.function
@ -77,6 +78,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(ValueError):
fn(1.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
state = []
@ -96,6 +98,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(f(range(5)), 1.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testCorrectVariableCreation(self):
state = []
@ -109,6 +112,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testFunctionInitializer(self):
state = []
@ -121,6 +125,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testFunctionMultipleVariableInitializer(self):
state = []
@ -134,6 +139,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])
@test_util.disable_tfrt('Variable argument is not supported')
def testFunctionInitializationFunction(self):
state = []
@ -151,6 +157,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
init_fn()
self.assertEqual(state[0].numpy(), 2.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testVariableInitializerNotConstant(self):
state = []
@ -180,6 +187,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(sess.run(state[0]), 2.0)
self.assertAllEqual(self.evaluate(result), 6.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
with ops.Graph().as_default(), self.test_session() as sess:
state = []
@ -199,6 +207,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(sess.run(state[0]), 6.0)
self.assertAllEqual(self.evaluate(result), 18.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testLegacyGraphModeInputDependentInitializerFails(self):
with ops.Graph().as_default():
state = []
@ -213,6 +222,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
lift_to_graph.UnliftableError, r'transitively.* mul .* x'):
fn(constant_op.constant(3.0))
@test_util.disable_tfrt('Variable argument is not supported')
def testMethod(self):
class MyModel(object):
@ -241,6 +251,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
def_function.function(functools.partial(lambda x, y: x + y, 1.))(
constant_op.constant(2.)))
@test_util.disable_tfrt('Partial is not supported')
def test_functools_partial_new_default(self):
def f(x=3, y=7):
return x + y
@ -249,6 +260,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertEqual(func().numpy(), 9)
self.assertEqual(func(y=8).numpy(), 11)
@test_util.disable_tfrt('Partial is not supported')
def test_functools_partial_keywords(self):
def f(x, y):
return x + y
@ -257,6 +269,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
self.assertAllEqual(func(), [0.0])
@test_util.disable_tfrt('Partial is not supported')
def test_functools_partial_single_positional(self):
def f(x, y):
return x + y
@ -265,6 +278,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
functools.partial(f, constant_op.constant(1)))
self.assertAllEqual(func(5), 6)
@test_util.disable_tfrt('Partial is not supported')
def test_complicated_partial_with_defaults(self):
def identity(*args):
@ -312,6 +326,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
(tensor_spec.TensorSpec(
None, dtypes.float32, name='x'),))
@test_util.disable_tfrt('Variable argument is not supported')
@test_util.run_in_graph_and_eager_modes
def test_variable_naming(self):
class HasVars(module.Module):
@ -382,6 +397,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
'defined in another function or code block'):
f(array_ops.zeros(shape=(8, 42, 3)))
@test_util.disable_tfrt('Control flow is not supported')
def testRuntimeErrorNotSticky(self):
@def_function.function
@ -486,6 +502,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
constant_op.constant(3.),
constant_op.constant(4.)))
@test_util.disable_tfrt('Variable argument is not supported')
def testVariableCreatorScope(self):
created_variables = []
captured_variables = []
@ -505,6 +522,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
f()
self.assertEqual(created_variables, captured_variables)
@test_util.disable_tfrt('Variable argument is not supported')
def testVarAlreadyInitializedNoClobbering(self):
v_holder = []
@ -522,6 +540,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
add_var.get_concrete_function(constant_op.constant(2.))
self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
@test_util.disable_tfrt('Variable argument is not supported')
def testSameVariableTwice(self):
v = variables.Variable(1.0)
@ -531,6 +550,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(add(v, v), 2.0)
@test_util.disable_tfrt('Variable argument is not supported')
def testVariableUpdate(self):
v1 = variables.Variable(1.0)
v2 = variables.Variable(2.0)
@ -566,6 +586,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertIs(func_a, func_b)
@test_util.disable_tfrt('Nested function is not supported')
def testInitializationInNestedCall(self):
v_holder = []
@ -588,6 +609,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
v_holder[1].assign(11.)
self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))
@test_util.disable_tfrt('Variable argument is not supported')
@test_util.run_gpu_only
def testDeviceAnnotationRespected(self):
a = []
@ -607,6 +629,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
create_variable()
self.assertRegexpMatches(a[0].device, 'CPU')
@test_util.disable_tfrt('Variable argument is not supported')
@test_util.run_gpu_only
def testDeviceAnnotationForInitializerRespected(self):
a = []
@ -681,6 +704,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(cloned(x)),
self.evaluate(cloned_py_function(x)))
@test_util.disable_tfrt('Variable argument is not supported')
def testLiftPlaceholderInitializedVariable(self):
with ops.Graph().as_default():
var_list = []
@ -834,6 +858,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
self.assertLen(logs.output, 1)
self.assertIn('Tracing is expensive', logs.output[0])
@test_util.disable_tfrt('Nested function is not supported')
def test_frequent_retracing_warning_nested(self):
if sys.version_info[0] < 3:
self.skipTest('self.assertLogs() call is not available in Python 2.')