Make TFE_TensorHandleCache aware of TFE_Context.

Some tests reset the global context, and TensorHandleCache was incorrectly shared across different contexts before.

PiperOrigin-RevId: 334870644
Change-Id: I6159ca89d67b3939e98d7b43ce847ac205abf650
This commit is contained in:
Kibeom Kim 2020-10-01 12:18:19 -07:00 committed by TensorFlower Gardener
parent 0fbd189a50
commit f16b71ccf0
4 changed files with 26 additions and 11 deletions

View File

@ -303,9 +303,7 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
try: try:
config.set_device_policy('silent') config.set_device_policy('silent')
config.set_soft_device_placement(True) config.set_soft_device_placement(True)
# Avoid the TensorHandle cache hit. cpu_tensor = constant_op.constant(1.0)
# TODO(b/169790439): include Context to the TensorHandle cache.
cpu_tensor = constant_op.constant(1.1)
result = cpu_tensor + cpu_tensor result = cpu_tensor + cpu_tensor
self.assertEqual(result.device, self.assertEqual(result.device,
'/job:localhost/replica:0/task:0/device:GPU:0') '/job:localhost/replica:0/task:0/device:GPU:0')
@ -504,6 +502,18 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
context.async_clear_error() context.async_clear_error()
config.set_synchronous_execution(True) config.set_synchronous_execution(True)
def testCrossContextTensorCache(self):
old_context = context.context()
old_x = constant_op.constant(9.5)
context._set_context(context.Context())
try:
new_x = constant_op.constant(9.5)
self.assertEqual(new_x.numpy(), 9.5)
finally:
context._set_context(old_context)
self.assertEqual(old_x.numpy(), 9.5)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -335,12 +335,12 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
// TODO(slebedev): also cache singleton NumPy arrays and scalars? // TODO(slebedev): also cache singleton NumPy arrays and scalars?
if (PyArray_IsPythonNumber(value)) { if (PyArray_IsPythonNumber(value)) {
auto* cache = TFE_TensorHandleCache::Get(); auto* cache = TFE_TensorHandleCache::Get();
TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name); TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
if (handle != nullptr) return handle; if (handle != nullptr) return handle;
handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name); handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
if (handle == nullptr) return nullptr; if (handle == nullptr) return nullptr;
if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) { if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
cache->Insert(value, dtype, device_name, handle); cache->Insert(value, dtype, ctx, device_name, handle);
} }
return handle; return handle;
} else { } else {

View File

@ -38,10 +38,10 @@ TFE_TensorHandleCache* TFE_TensorHandleCache::Get() {
} }
TFE_TensorHandle* TFE_TensorHandleCache::Lookup( TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
PyObject* value, tensorflow::DataType dtype, PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
absl::string_view device_name) const { absl::string_view device_name) const {
CHECK_NOTNULL(value); CHECK_NOTNULL(value);
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name}); const auto it = cache.find(Key{PyObjectPtr{value}, dtype, ctx, device_name});
if (it == cache.end()) { if (it == cache.end()) {
scalar_cache_misses->GetCell()->IncrementBy(1); scalar_cache_misses->GetCell()->IncrementBy(1);
return nullptr; return nullptr;
@ -53,10 +53,11 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
} }
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype, void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
TFE_Context* ctx,
absl::string_view device_name, absl::string_view device_name,
TFE_TensorHandle* h) { TFE_TensorHandle* h) {
Py_INCREF(value); Py_INCREF(value);
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name}, cache.emplace(Key{PyObjectPtr{value}, dtype, ctx, device_name},
tensorflow::wrap(tensorflow::unwrap(h)->Copy())); tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
} }

View File

@ -73,16 +73,20 @@ struct TFE_TensorHandleCache {
~TFE_TensorHandleCache() { DecrefUnrefAll(); } ~TFE_TensorHandleCache() { DecrefUnrefAll(); }
TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype, TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype,
TFE_Context* ctx,
absl::string_view device_name) const; absl::string_view device_name) const;
void Insert(PyObject* value, tensorflow::DataType dtype, void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
absl::string_view device_name, TFE_TensorHandle* h); absl::string_view device_name, TFE_TensorHandle* h);
void Clear(); void Clear();
private: private:
// TODO(slebedev): should the key depend on TFE_Context? // TODO(kkb): Instead of `TFE_Context*` key, ideally Python's context object
using Key = std::tuple<PyObjectPtr, tensorflow::DataType, absl::string_view>; // should have a TFE_TensorHandleCache instance. Migrate once the Python context
// object is backed by a C++ data structure. b/169790439
using Key = std::tuple<PyObjectPtr, tensorflow::DataType, TFE_Context*,
absl::string_view>;
void DecrefUnrefAll() { void DecrefUnrefAll() {
for (const auto& p : cache) { for (const auto& p : cache) {