From 896635c9cb1e794e8bd26b23beb9311bd03a2b2b Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 10 Sep 2020 12:31:56 -0700 Subject: [PATCH] [C++ gradients] Add experimental python bindings for unified API's Tape. See unified_api_test for usage. PiperOrigin-RevId: 330992900 Change-Id: I1e7b2fc3ce6ab7d53c8f05a9807e7a03f6e59715 --- tensorflow/c/eager/BUILD | 33 ++------ tensorflow/c/eager/mnist_gradients_testutil.h | 4 + tensorflow/c/experimental/gradients/BUILD | 36 ++++++++- tensorflow/c/experimental/ops/BUILD | 5 ++ tensorflow/python/BUILD | 7 ++ .../python/framework/experimental/BUILD | 46 ++++++++++- .../experimental/gradient_registry.py | 27 +++++++ .../python/framework/experimental/math_ops.cc | 25 +++++- .../python/framework/experimental/math_ops.py | 6 +- .../python/framework/experimental/tape.cc | 79 +++++++++++++++++++ .../python/framework/experimental/tape.py | 54 +++++++++++++ .../framework/experimental/tape_stack.py | 35 ++++++++ .../experimental/unified_api_test.py | 33 ++++++++ 13 files changed, 353 insertions(+), 37 deletions(-) create mode 100644 tensorflow/python/framework/experimental/gradient_registry.py create mode 100644 tensorflow/python/framework/experimental/tape.cc create mode 100644 tensorflow/python/framework/experimental/tape.py create mode 100644 tensorflow/python/framework/experimental/tape_stack.py diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 5f4d6720ac3..c095f12e0ba 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -103,9 +103,13 @@ filegroup( "c_api_unified_experimental.h", "c_api_unified_experimental_internal.h", "dlpack.h", + "gradients.h", + "gradients_internal.h", "immediate_execution_context.h", "immediate_execution_operation.h", "immediate_execution_tensor_handle.h", + "mnist_gradients_testutil.h", + "tape.h", "tfe_cancellation_manager_internal.h", "tfe_executor_internal.h", "tfe_monitoring_internal.h", @@ -164,31 +168,6 @@ cc_library( ], ) -cc_library( - name = "gradients", - srcs = [ - "gradients.cc", - "gradients_internal.h", - ], - hdrs = [ - "gradients.h", - ], - visibility = [ - "//tensorflow:internal", - ], - deps = [ - ":abstract_context", - ":abstract_operation", - ":abstract_tensor_handle", - ":c_api_unified_internal", - ":tape", - "//tensorflow/core/common_runtime/eager:attr_builder", - "//tensorflow/core/lib/llvm_rtti", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "gradients_internal", srcs = [ @@ -266,9 +245,7 @@ cc_library( ":gradients_internal", "//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_tensor", - "//tensorflow/c/experimental/ops:array_ops", - "//tensorflow/c/experimental/ops:math_ops", - "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/c/experimental/ops", "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h index efe196e9ba3..3407a83361a 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ +#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ #include #include "absl/types/span.h" @@ -148,3 +150,5 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); } // namespace internal } // namespace gradients } // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 36a3251def7..53423edd0a7 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -16,7 +16,7 @@ cc_library( "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", "//tensorflow/core/lib/llvm_rtti", ], ) @@ -34,7 +34,7 @@ cc_library( "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", @@ -55,10 +55,40 @@ cc_library( "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/c/eager:gradients", + "//tensorflow/c/eager:gradients_internal", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/core/lib/llvm_rtti", ], ) + +cc_library( + name = "gradients", + hdrs = [ + "array_grad.h", + "math_grad.h", + "nn_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_grad", + ":math_grad", + ":nn_grad", + ], +) + +filegroup( + name = "pywrap_required_hdrs", + srcs = [ + "array_grad.h", + "math_grad.h", + "nn_grad.h", + ], + visibility = [ + "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", + ], +) diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index 6a8843595db..9c8dbfcbfb8 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -68,6 +68,11 @@ cc_library( cc_library( name = "ops", + hdrs = [ + "array_ops.h", + "math_ops.h", + "nn_ops.h", + ], visibility = [ "//tensorflow:internal", ], diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 9e0eacf7b77..a5f29a5735e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -682,6 +682,7 @@ tf_python_pybind_extension( "//tensorflow/c:headers", "//tensorflow/c/eager:headers", "//tensorflow/c/eager:pywrap_required_hdrs", + "//tensorflow/c/experimental/ops:pywrap_required_hdrs", "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", @@ -702,6 +703,7 @@ tf_python_pybind_extension( "//tensorflow/core/common_runtime:core_cpu_headers_lib", "//tensorflow/core:lib_headers_for_pybind", "@com_google_absl//absl/types:optional", + "//tensorflow/core/lib/llvm_rtti", ] + if_static( extra_deps = [ "//tensorflow/core/protobuf:eager_service_proto_cc", @@ -5939,6 +5941,8 @@ pywrap_tensorflow_macro( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/experimental/ops", + "//tensorflow/c/experimental/gradients", + "//tensorflow/c/eager:mnist_gradients_testutil", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core/data/service:server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration", @@ -7448,6 +7452,7 @@ cc_library( "//tensorflow/c/eager:headers", "//tensorflow/c/eager:pywrap_required_hdrs", "//tensorflow/c/experimental/ops:pywrap_required_hdrs", + "//tensorflow/c/experimental/gradients:pywrap_required_hdrs", "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", "//tensorflow/python/eager:pywrap_required_hdrs", "//tensorflow/core/common_runtime:core_cpu_lib_headers", @@ -7484,6 +7489,7 @@ tf_python_pybind_extension( "//tensorflow/c:headers", "//tensorflow/c/eager:headers", "//tensorflow/c/eager:pywrap_required_hdrs", + "//tensorflow/c/experimental/ops:pywrap_required_hdrs", "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", @@ -7511,6 +7517,7 @@ tf_python_pybind_extension( "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform", + "//tensorflow/core/lib/llvm_rtti", ] + if_static( extra_deps = [ "//tensorflow/core/protobuf:eager_service_proto_cc", diff --git a/tensorflow/python/framework/experimental/BUILD b/tensorflow/python/framework/experimental/BUILD index 4b7f4fd2b04..bedd5b2945e 100644 --- a/tensorflow/python/framework/experimental/BUILD +++ b/tensorflow/python/framework/experimental/BUILD @@ -27,6 +27,22 @@ tf_python_pybind_extension( ], ) +tf_python_pybind_extension( + name = "_tape", + srcs = ["tape.cc"], + features = ["-layering_check"], + module_name = "_tape", + deps = [ + "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/python:pybind11_lib", + "//tensorflow/python:unified_api_pywrap_required_headers", + "@pybind11", + ], +) + tf_python_pybind_extension( name = "_math_ops", srcs = ["math_ops.cc"], @@ -44,10 +60,31 @@ tf_python_pybind_extension( ], ) +py_library( + name = "gradient_registry", + srcs = ["gradient_registry.py"], + deps = [":_tape"], +) + py_library( name = "math_ops", srcs = ["math_ops.py"], - deps = ["_math_ops"], + deps = [ + ":_math_ops", + ":gradient_registry", + ":tape_stack", + ], +) + +py_library( + name = "tape", + srcs = ["tape.py"], + deps = [ + ":_tape", + ":context_stack", + ":tape_stack", + "//tensorflow/python/data/util:nest", + ], ) py_library( @@ -66,6 +103,12 @@ py_library( deps = [":thread_local_stack"], ) +py_library( + name = "tape_stack", + srcs = ["tape_stack.py"], + deps = [":thread_local_stack"], +) + cuda_py_test( name = "unified_api_test", size = "small", @@ -82,6 +125,7 @@ cuda_py_test( ":context_stack", ":def_function", ":math_ops", + ":tape", "//tensorflow/python:client_testlib", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/framework/experimental/gradient_registry.py b/tensorflow/python/framework/experimental/gradient_registry.py new file mode 100644 index 00000000000..6cba4001553 --- /dev/null +++ b/tensorflow/python/framework/experimental/gradient_registry.py @@ -0,0 +1,27 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Global GradientRegistry.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework.experimental import _tape + +_GRADIENT_REGISTRY_GLOBAL = _tape.GradientRegistry() + + +def get_global_registry(): + return _GRADIENT_REGISTRY_GLOBAL diff --git a/tensorflow/python/framework/experimental/math_ops.cc b/tensorflow/python/framework/experimental/math_ops.cc index cbf2568b607..a0ca1c1bb89 100644 --- a/tensorflow/python/framework/experimental/math_ops.cc +++ b/tensorflow/python/framework/experimental/math_ops.cc @@ -23,22 +23,39 @@ limitations under the License. #include "pybind11/pybind11.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/gradients.h" #include "tensorflow/python/lib/core/pybind11_status.h" +// This library provides helpers for running ops and recording them on a +// GradientTape. This is currently needed because the tape does not provide +// an implementation of the abstract execution APIs but that will change. +// TODO(b/168209775): Remove this and its imported symbols once the tape +// execution context is ready. +#include "tensorflow/c/eager/mnist_gradients_testutil.h" + using tensorflow::AbstractContext; using tensorflow::AbstractTensorHandle; -using tensorflow::ops::Add; +using tensorflow::gradients::GradientRegistry; +using tensorflow::gradients::Tape; +namespace tensorflow { PYBIND11_MODULE(_math_ops, m) { m.def("add", [](AbstractContext* ctx, AbstractTensorHandle* a, - AbstractTensorHandle* b, const char* name) { + AbstractTensorHandle* b, const char* name, Tape* tape, + GradientRegistry* registry) { int num_outputs = 1; std::vector outputs(1); if (!name) { name = "Add"; } - MaybeRaiseRegisteredFromStatus( - Add(ctx, {a, b}, absl::MakeSpan(outputs), name)); + if (!tape) { + MaybeRaiseRegisteredFromStatus( + ops::Add(ctx, {a, b}, absl::MakeSpan(outputs), name)); + } else { + MaybeRaiseRegisteredFromStatus(gradients::internal::Add( + ctx, tape, {a, b}, absl::MakeSpan(outputs), *registry)); + } return outputs[0]; }); } +} // namespace tensorflow diff --git a/tensorflow/python/framework/experimental/math_ops.py b/tensorflow/python/framework/experimental/math_ops.py index c03dcc6e082..2aefe3aa574 100644 --- a/tensorflow/python/framework/experimental/math_ops.py +++ b/tensorflow/python/framework/experimental/math_ops.py @@ -20,8 +20,12 @@ from __future__ import print_function from tensorflow.python.framework.experimental import _math_ops from tensorflow.python.framework.experimental import context_stack as context +from tensorflow.python.framework.experimental import gradient_registry +from tensorflow.python.framework.experimental import tape_stack def add(a, b, name=None): ctx = context.get_default() - return _math_ops.add(ctx, a, b, name) + tape = tape_stack.get_default() + grad_registry = gradient_registry.get_global_registry() + return _math_ops.add(ctx, a, b, name, tape, grad_registry) diff --git a/tensorflow/python/framework/experimental/tape.cc b/tensorflow/python/framework/experimental/tape.cc new file mode 100644 index 00000000000..003a70141dd --- /dev/null +++ b/tensorflow/python/framework/experimental/tape.cc @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "pybind11/pybind11.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/python/lib/core/pybind11_status.h" + +namespace py = pybind11; + +namespace tensorflow { +namespace gradients { + +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); + TF_RETURN_IF_ERROR( + registry->Register("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyLossRegisterer)); + return Status::OK(); +} + +PYBIND11_MODULE(_tape, m) { + py::class_(m, "Tape") + .def(py::init([](bool persistent) { return new Tape(persistent); })) + .def("Watch", + [](Tape* self, AbstractTensorHandle* t) { self->Watch(ToId(t)); }) + .def("ComputeGradient", + [](Tape* self, TapeVSpace* vspace, + std::vector target_tensors, + std::vector source_tensors, + std::vector output_gradients) { + std::vector target_tensor_ids; + std::vector source_tensor_ids; + target_tensor_ids.reserve(target_tensors.size()); + source_tensor_ids.reserve(source_tensors.size()); + for (auto t : target_tensors) { + target_tensor_ids.emplace_back(ToId(t)); + } + for (auto t : source_tensors) { + source_tensor_ids.emplace_back(ToId(t)); + } + std::unordered_map + source_tensors_that_are_targets; + std::vector results; + Status s = self->ComputeGradient( + *vspace, target_tensor_ids, source_tensor_ids, + source_tensors_that_are_targets, output_gradients, &results, + /*build_default_zeros_grads=*/false); + MaybeRaiseRegisteredFromStatus(s); + return results; + }); + py::class_(m, "TapeVSpace") + .def(py::init([](AbstractContext* ctx) { return new TapeVSpace(ctx); })); + py::class_(m, "GradientRegistry").def(py::init([]() { + auto registry = new GradientRegistry(); + MaybeRaiseRegisteredFromStatus(RegisterGradients(registry)); + return registry; + })); +} +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/python/framework/experimental/tape.py b/tensorflow/python/framework/experimental/tape.py new file mode 100644 index 00000000000..47ce781ee7c --- /dev/null +++ b/tensorflow/python/framework/experimental/tape.py @@ -0,0 +1,54 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental impl for GradientTape using unified APIs, for testing only.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework.experimental import _tape +from tensorflow.python.framework.experimental import context_stack +from tensorflow.python.framework.experimental import tape_stack +from tensorflow.python.util import nest + + +class GradientTape(object): + """GradientTape using the unified API.""" + + def __init__(self, persistent=False): + self._c_tape = _tape.Tape(persistent) + + def watch(self, t): + self._c_tape.Watch(t) + + # TODO(srbs): Add support for unconnected_gradients. + def gradient(self, targets, sources, output_gradients=None): + ctx = context_stack.get_default() + vspace = _tape.TapeVSpace(ctx) + flat_targets = nest.flatten(targets) + flat_sources = nest.flatten(sources) + out_grads = self._c_tape.ComputeGradient(vspace, flat_targets, flat_sources, + output_gradients or []) + return nest.pack_sequence_as(sources, out_grads) + + def __enter__(self): + """Enters a context inside which operations are recorded on this tape.""" + if tape_stack.get_default(): + raise ValueError("Nested tapes are not supported yet.") + tape_stack.push(self._c_tape) + return self + + def __exit__(self, typ, value, traceback): + tape_stack.pop() diff --git a/tensorflow/python/framework/experimental/tape_stack.py b/tensorflow/python/framework/experimental/tape_stack.py new file mode 100644 index 00000000000..20583e699bb --- /dev/null +++ b/tensorflow/python/framework/experimental/tape_stack.py @@ -0,0 +1,35 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Thread-local context manager stack.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework.experimental import thread_local_stack + +_default_tape_stack = thread_local_stack.ThreadLocalStack() + + +def get_default(): + return _default_tape_stack.peek() + + +def push(tape): + _default_tape_stack.push(tape) + + +def pop(): + _default_tape_stack.pop() diff --git a/tensorflow/python/framework/experimental/unified_api_test.py b/tensorflow/python/framework/experimental/unified_api_test.py index 88cd7657a67..8587c1aa3fd 100644 --- a/tensorflow/python/framework/experimental/unified_api_test.py +++ b/tensorflow/python/framework/experimental/unified_api_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework.experimental import _unified_api from tensorflow.python.framework.experimental import context_stack as context_lib from tensorflow.python.framework.experimental import def_function from tensorflow.python.framework.experimental import math_ops +from tensorflow.python.framework.experimental import tape as tape_lib from tensorflow.python.platform import test NewImmediateExecutionContext = _unified_api.NewImmediateExecutionContext @@ -57,6 +58,38 @@ class UnifiedApiTest(test.TestCase, parameterized.TestCase): eager_output = model(a, b) self.assertAllEqual(eager_output.numpy(), 3.0) + @parameterized.named_parameters([ + ("EagerGraph", False, False), + ("EagerMlir", False, True), + # TODO(srbs): Enable for TFRT. Segfaults right now. + # ("TfrtGraph", True, False), + # ("TfrtMlir", True, True), + ]) + def testAddGrad(self, use_tfrt, use_mlir): + if use_mlir: + SetTracingImplementation("mlir") + + def model(a, b): + with tape_lib.GradientTape() as tape: + tape.watch(a) + tape.watch(b) + result = math_ops.add(a, b) + grads = tape.gradient(result, [a, b]) + return grads + + eager_ctx = NewImmediateExecutionContext(use_tfrt) + with context_lib.set_default(eager_ctx): + a = eager_ctx.CreateFloatScalarHandle(1.) + b = eager_ctx.CreateFloatScalarHandle(2.) + + func_outputs = def_function.function(model)(a, b) + self.assertAllEqual(func_outputs[0].numpy(), 1.0) + self.assertAllEqual(func_outputs[1].numpy(), 1.0) + + eager_outputs = model(a, b) + self.assertAllEqual(eager_outputs[0].numpy(), 1.0) + self.assertAllEqual(eager_outputs[1].numpy(), 1.0) + if __name__ == "__main__": test.main()