From f16b71ccf0a137e6c2b4fbee78606116852bd97b Mon Sep 17 00:00:00 2001 From: Kibeom Kim Date: Thu, 1 Oct 2020 12:18:19 -0700 Subject: [PATCH] Make TFE_TensorHandleCache aware of TFE_Context. Some tests reset the global context, and TensorHandleCache was incorrectly shared across different context before. PiperOrigin-RevId: 334870644 Change-Id: I6159ca89d67b3939e98d7b43ce847ac205abf650 --- tensorflow/python/eager/ops_test.py | 16 +++++++++++++--- tensorflow/python/eager/pywrap_tensor.cc | 4 ++-- .../python/eager/pywrap_tensor_conversion.cc | 7 ++++--- .../python/eager/pywrap_tensor_conversion.h | 10 +++++++--- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/eager/ops_test.py b/tensorflow/python/eager/ops_test.py index f7f44b83ea1..0c8bbe76c98 100644 --- a/tensorflow/python/eager/ops_test.py +++ b/tensorflow/python/eager/ops_test.py @@ -303,9 +303,7 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): try: config.set_device_policy('silent') config.set_soft_device_placement(True) - # Avoid the TensorHandle cache hit. - # TODO(b/169790439): include Context to the TensorHandle cache. 
- cpu_tensor = constant_op.constant(1.1) + cpu_tensor = constant_op.constant(1.0) result = cpu_tensor + cpu_tensor self.assertEqual(result.device, '/job:localhost/replica:0/task:0/device:GPU:0') @@ -504,6 +502,18 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): context.async_clear_error() config.set_synchronous_execution(True) + def testCrossContextTensorCache(self): + old_context = context.context() + old_x = constant_op.constant(9.5) + context._set_context(context.Context()) + + try: + new_x = constant_op.constant(9.5) + self.assertEqual(new_x.numpy(), 9.5) + finally: + context._set_context(old_context) + + self.assertEqual(old_x.numpy(), 9.5) if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index e5c74deaf80..bdd17c889e6 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -335,12 +335,12 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value, // TODO(slebedev): also cache singleton NumPy arrays and scalars? 
if (PyArray_IsPythonNumber(value)) { auto* cache = TFE_TensorHandleCache::Get(); - TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name); + TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name); if (handle != nullptr) return handle; handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name); if (handle == nullptr) return nullptr; if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) { - cache->Insert(value, dtype, device_name, handle); + cache->Insert(value, dtype, ctx, device_name, handle); } return handle; } else { diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.cc b/tensorflow/python/eager/pywrap_tensor_conversion.cc index 041ddf4ec53..432082433d7 100644 --- a/tensorflow/python/eager/pywrap_tensor_conversion.cc +++ b/tensorflow/python/eager/pywrap_tensor_conversion.cc @@ -38,10 +38,10 @@ TFE_TensorHandleCache* TFE_TensorHandleCache::Get() { } TFE_TensorHandle* TFE_TensorHandleCache::Lookup( - PyObject* value, tensorflow::DataType dtype, + PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx, absl::string_view device_name) const { CHECK_NOTNULL(value); - const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name}); + const auto it = cache.find(Key{PyObjectPtr{value}, dtype, ctx, device_name}); if (it == cache.end()) { scalar_cache_misses->GetCell()->IncrementBy(1); return nullptr; @@ -53,10 +53,11 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup( } void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, + TFE_Context* ctx, absl::string_view device_name, TFE_TensorHandle* h) { Py_INCREF(value); - cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, + cache.emplace(Key{PyObjectPtr{value}, dtype, ctx, device_name}, tensorflow::wrap(tensorflow::unwrap(h)->Copy())); } diff --git a/tensorflow/python/eager/pywrap_tensor_conversion.h b/tensorflow/python/eager/pywrap_tensor_conversion.h index 8890979c379..5bd851ed5b2 100644 --- 
a/tensorflow/python/eager/pywrap_tensor_conversion.h +++ b/tensorflow/python/eager/pywrap_tensor_conversion.h @@ -73,16 +73,20 @@ struct TFE_TensorHandleCache { ~TFE_TensorHandleCache() { DecrefUnrefAll(); } TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype, + TFE_Context* ctx, absl::string_view device_name) const; - void Insert(PyObject* value, tensorflow::DataType dtype, + void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx, absl::string_view device_name, TFE_TensorHandle* h); void Clear(); private: - // TODO(slebedev): should the key depend on TFE_Context? - using Key = std::tuple<PyObjectPtr, tensorflow::DataType, absl::string_view>; + // TODO(kkb): Instead of a `TFE_Context*` key, ideally the Python context object + // should own a TFE_TensorHandleCache instance. Migrate once the Python context + // object is backed by a C++ data structure. b/169790439 + using Key = std::tuple<PyObjectPtr, tensorflow::DataType, TFE_Context*, absl::string_view>; void DecrefUnrefAll() { for (const auto& p : cache) {