Initial @tf.function support in TFRT. And enable some tests in def_function_test.py.
Some limitations of current implementation: 1. We doesn't have a cache for compiled function. Right now it JIT compile the function in each invocation. 2. Does not support nested function. 3. Does not support variable. PiperOrigin-RevId: 315031999 Change-Id: I8d96ed26d0da7c071b7f89e65d6ada7bbc290a37
This commit is contained in:
parent
e59ee30dd2
commit
54d74710e4
@ -1397,23 +1397,17 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx,
|
||||
tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
|
||||
return;
|
||||
}
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function_def);
|
||||
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function_def);
|
||||
}
|
||||
|
||||
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->AddFunctionDef(function->fdef);
|
||||
status->status = tensorflow::unwrap(ctx)->AddFunctionDef(function->fdef);
|
||||
}
|
||||
|
||||
void TFE_ContextRemoveFunction(TFE_Context* ctx, const char* name,
|
||||
TF_Status* status) {
|
||||
tensorflow::EagerContext* context =
|
||||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
status->status = context->RemoveFunction(name);
|
||||
status->status = tensorflow::unwrap(ctx)->RemoveFunction(name);
|
||||
}
|
||||
|
||||
unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
|
||||
|
@ -104,6 +104,14 @@ class AbstractContextInterface {
|
||||
// Block until all pending nodes are finished.
|
||||
virtual Status AsyncWait() = 0;
|
||||
|
||||
// Add a function (serialized FunctionDef protocol buffer) so that it can
|
||||
// be executed as an op. Return error if the function with the same name
|
||||
// already exists.
|
||||
virtual Status AddFunctionDef(const FunctionDef& fdef) = 0;
|
||||
// Remove a function. 'func' argument is the name of a previously added
|
||||
// FunctionDef. The name is in fdef.signature.name.
|
||||
virtual Status RemoveFunction(const string& func) = 0;
|
||||
|
||||
protected:
|
||||
virtual ~AbstractContextInterface() {}
|
||||
};
|
||||
|
@ -3572,6 +3572,24 @@ StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
/*func_name=*/"main");
|
||||
}
|
||||
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
|
||||
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
|
||||
mlir::MLIRContext* context) {
|
||||
const tensorflow::FunctionDef* fdef = flib_def.Find(name.str());
|
||||
if (fdef == nullptr)
|
||||
return tensorflow::errors::NotFound("Cannot find function ", name.str());
|
||||
|
||||
std::unique_ptr<tensorflow::FunctionBody> fbody;
|
||||
TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(),
|
||||
&flib_def, &fbody));
|
||||
|
||||
tensorflow::GraphDebugInfo dummy_debug_info;
|
||||
tensorflow::GraphImportConfig specs;
|
||||
specs.graph_as_function = true;
|
||||
return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info,
|
||||
flib_def, specs, name);
|
||||
}
|
||||
|
||||
StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
SavedModelV2Bundle* saved_model, mlir::MLIRContext* context,
|
||||
absl::Span<std::string> exported_names, bool add_default_attributes) {
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/cc/saved_model/bundle_v2.h"
|
||||
#include "tensorflow/cc/saved_model/loader.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
|
||||
@ -45,6 +46,13 @@ stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertGraphToMlir(
|
||||
const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// [Experimental]
|
||||
// Given a Function, returns a MLIR module containing the graph, expressed with
|
||||
// tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertFunctionToMlir(
|
||||
mlir::StringRef name, const FunctionLibraryDefinition& flib_def,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Given a SavedModel, returns a MLIR module containing the functions, expressed
|
||||
// with tf_executor dialect.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ConvertSavedModelToMlir(
|
||||
|
@ -275,7 +275,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
|
||||
// Add the given `fdef` to the local FunctionLibraryDefinition. And add an
|
||||
// entry to the KernelAndDevice cache for it if it's not exist.
|
||||
Status AddFunctionDef(const FunctionDef& fdef);
|
||||
Status AddFunctionDef(const FunctionDef& fdef) override;
|
||||
// `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add
|
||||
// it to the local FunctionLibraryDefinition as well, but no need to add it
|
||||
// to the KernelAndDevice cache since they won't be executed as
|
||||
@ -286,7 +286,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
||||
|
||||
const FunctionDef* GetFunctionDef(const string& function_name);
|
||||
|
||||
Status RemoveFunction(const string& func);
|
||||
Status RemoveFunction(const string& func) override;
|
||||
|
||||
// Wait for pending nodes to be finished in local executors (including context
|
||||
// default executor and thread executors) and executors on remote workers.
|
||||
|
@ -803,6 +803,7 @@ cuda_py_test(
|
||||
name = "def_function_test",
|
||||
srcs = ["def_function_test.py"],
|
||||
python_version = "PY3",
|
||||
tfrt_enabled = True,
|
||||
deps = [
|
||||
":def_function",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -707,7 +707,6 @@ class MicroBenchmarks(benchmarks_test_base.MicroBenchmarksBase):
|
||||
self._benchmark_tfe_py_execute_matmul(
|
||||
m, transpose_b=True, num_iters=self._num_iters_100_by_784)
|
||||
|
||||
@test_util.disable_tfrt("Graph is not supported yet. b/156187905")
|
||||
def benchmark_defun_matmul_100_by_784_CPU(self):
|
||||
with context.device(CPU):
|
||||
m = self._m_100_by_784.cpu()
|
||||
|
@ -68,6 +68,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testFailIfVariablesAreCreatedMoreThanOnce(self):
|
||||
|
||||
@def_function.function
|
||||
@ -77,6 +78,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
fn(1.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
|
||||
state = []
|
||||
|
||||
@ -96,6 +98,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(f(range(5)), 1.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testCorrectVariableCreation(self):
|
||||
|
||||
state = []
|
||||
@ -109,6 +112,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
|
||||
self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testFunctionInitializer(self):
|
||||
|
||||
state = []
|
||||
@ -121,6 +125,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testFunctionMultipleVariableInitializer(self):
|
||||
|
||||
state = []
|
||||
@ -134,6 +139,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(fn(constant_op.constant(1.0)), [2.0, 5.0])
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testFunctionInitializationFunction(self):
|
||||
|
||||
state = []
|
||||
@ -151,6 +157,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
init_fn()
|
||||
self.assertEqual(state[0].numpy(), 2.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testVariableInitializerNotConstant(self):
|
||||
|
||||
state = []
|
||||
@ -180,6 +187,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(sess.run(state[0]), 2.0)
|
||||
self.assertAllEqual(self.evaluate(result), 6.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testLegacyGraphModeVariablesNonTrivialInitializer(self):
|
||||
with ops.Graph().as_default(), self.test_session() as sess:
|
||||
state = []
|
||||
@ -199,6 +207,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(sess.run(state[0]), 6.0)
|
||||
self.assertAllEqual(self.evaluate(result), 18.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testLegacyGraphModeInputDependentInitializerFails(self):
|
||||
with ops.Graph().as_default():
|
||||
state = []
|
||||
@ -213,6 +222,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
lift_to_graph.UnliftableError, r'transitively.* mul .* x'):
|
||||
fn(constant_op.constant(3.0))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testMethod(self):
|
||||
|
||||
class MyModel(object):
|
||||
@ -241,6 +251,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
def_function.function(functools.partial(lambda x, y: x + y, 1.))(
|
||||
constant_op.constant(2.)))
|
||||
|
||||
@test_util.disable_tfrt('Partial is not supported')
|
||||
def test_functools_partial_new_default(self):
|
||||
def f(x=3, y=7):
|
||||
return x + y
|
||||
@ -249,6 +260,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(func().numpy(), 9)
|
||||
self.assertEqual(func(y=8).numpy(), 11)
|
||||
|
||||
@test_util.disable_tfrt('Partial is not supported')
|
||||
def test_functools_partial_keywords(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
@ -257,6 +269,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
functools.partial(f, x=array_ops.zeros([1]), y=array_ops.zeros([1])))
|
||||
self.assertAllEqual(func(), [0.0])
|
||||
|
||||
@test_util.disable_tfrt('Partial is not supported')
|
||||
def test_functools_partial_single_positional(self):
|
||||
def f(x, y):
|
||||
return x + y
|
||||
@ -265,6 +278,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
functools.partial(f, constant_op.constant(1)))
|
||||
self.assertAllEqual(func(5), 6)
|
||||
|
||||
@test_util.disable_tfrt('Partial is not supported')
|
||||
def test_complicated_partial_with_defaults(self):
|
||||
|
||||
def identity(*args):
|
||||
@ -312,6 +326,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
(tensor_spec.TensorSpec(
|
||||
None, dtypes.float32, name='x'),))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_variable_naming(self):
|
||||
class HasVars(module.Module):
|
||||
@ -382,6 +397,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
'defined in another function or code block'):
|
||||
f(array_ops.zeros(shape=(8, 42, 3)))
|
||||
|
||||
@test_util.disable_tfrt('Control flow is not supported')
|
||||
def testRuntimeErrorNotSticky(self):
|
||||
|
||||
@def_function.function
|
||||
@ -486,6 +502,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
constant_op.constant(3.),
|
||||
constant_op.constant(4.)))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testVariableCreatorScope(self):
|
||||
created_variables = []
|
||||
captured_variables = []
|
||||
@ -505,6 +522,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
f()
|
||||
self.assertEqual(created_variables, captured_variables)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testVarAlreadyInitializedNoClobbering(self):
|
||||
v_holder = []
|
||||
|
||||
@ -522,6 +540,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
add_var.get_concrete_function(constant_op.constant(2.))
|
||||
self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testSameVariableTwice(self):
|
||||
v = variables.Variable(1.0)
|
||||
|
||||
@ -531,6 +550,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(add(v, v), 2.0)
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testVariableUpdate(self):
|
||||
v1 = variables.Variable(1.0)
|
||||
v2 = variables.Variable(2.0)
|
||||
@ -566,6 +586,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.assertIs(func_a, func_b)
|
||||
|
||||
@test_util.disable_tfrt('Nested function is not supported')
|
||||
def testInitializationInNestedCall(self):
|
||||
v_holder = []
|
||||
|
||||
@ -588,6 +609,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
v_holder[1].assign(11.)
|
||||
self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
@test_util.run_gpu_only
|
||||
def testDeviceAnnotationRespected(self):
|
||||
a = []
|
||||
@ -607,6 +629,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
create_variable()
|
||||
self.assertRegexpMatches(a[0].device, 'CPU')
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
@test_util.run_gpu_only
|
||||
def testDeviceAnnotationForInitializerRespected(self):
|
||||
a = []
|
||||
@ -681,6 +704,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(cloned(x)),
|
||||
self.evaluate(cloned_py_function(x)))
|
||||
|
||||
@test_util.disable_tfrt('Variable argument is not supported')
|
||||
def testLiftPlaceholderInitializedVariable(self):
|
||||
with ops.Graph().as_default():
|
||||
var_list = []
|
||||
@ -834,6 +858,7 @@ class DefFunctionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertLen(logs.output, 1)
|
||||
self.assertIn('Tracing is expensive', logs.output[0])
|
||||
|
||||
@test_util.disable_tfrt('Nested function is not supported')
|
||||
def test_frequent_retracing_warning_nested(self):
|
||||
if sys.version_info[0] < 3:
|
||||
self.skipTest('self.assertLogs() call is not available in Python 2.')
|
||||
|
Loading…
Reference in New Issue
Block a user