Make TFE_TensorHandleCache aware of TFE_Context.
Some tests reset the global context, and TensorHandleCache was incorrectly shared across different context before. PiperOrigin-RevId: 334870644 Change-Id: I6159ca89d67b3939e98d7b43ce847ac205abf650
This commit is contained in:
parent
0fbd189a50
commit
f16b71ccf0
tensorflow/python/eager
@ -303,9 +303,7 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
try:
|
||||
config.set_device_policy('silent')
|
||||
config.set_soft_device_placement(True)
|
||||
# Avoid the TensorHandle cache hit.
|
||||
# TODO(b/169790439): include Context to the TensorHandle cache.
|
||||
cpu_tensor = constant_op.constant(1.1)
|
||||
cpu_tensor = constant_op.constant(1.0)
|
||||
result = cpu_tensor + cpu_tensor
|
||||
self.assertEqual(result.device,
|
||||
'/job:localhost/replica:0/task:0/device:GPU:0')
|
||||
@ -504,6 +502,18 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
context.async_clear_error()
|
||||
config.set_synchronous_execution(True)
|
||||
|
||||
def testCrossContextTensorCache(self):
|
||||
old_context = context.context()
|
||||
old_x = constant_op.constant(9.5)
|
||||
context._set_context(context.Context())
|
||||
|
||||
try:
|
||||
new_x = constant_op.constant(9.5)
|
||||
self.assertEqual(new_x.numpy(), 9.5)
|
||||
finally:
|
||||
context._set_context(old_context)
|
||||
|
||||
self.assertEqual(old_x.numpy(), 9.5)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -335,12 +335,12 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
|
||||
// TODO(slebedev): also cache singleton NumPy arrays and scalars?
|
||||
if (PyArray_IsPythonNumber(value)) {
|
||||
auto* cache = TFE_TensorHandleCache::Get();
|
||||
TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name);
|
||||
TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
|
||||
if (handle != nullptr) return handle;
|
||||
handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
|
||||
if (handle == nullptr) return nullptr;
|
||||
if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
|
||||
cache->Insert(value, dtype, device_name, handle);
|
||||
cache->Insert(value, dtype, ctx, device_name, handle);
|
||||
}
|
||||
return handle;
|
||||
} else {
|
||||
|
@ -38,10 +38,10 @@ TFE_TensorHandleCache* TFE_TensorHandleCache::Get() {
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
||||
PyObject* value, tensorflow::DataType dtype,
|
||||
PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
|
||||
absl::string_view device_name) const {
|
||||
CHECK_NOTNULL(value);
|
||||
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
|
||||
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, ctx, device_name});
|
||||
if (it == cache.end()) {
|
||||
scalar_cache_misses->GetCell()->IncrementBy(1);
|
||||
return nullptr;
|
||||
@ -53,10 +53,11 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
||||
}
|
||||
|
||||
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
|
||||
TFE_Context* ctx,
|
||||
absl::string_view device_name,
|
||||
TFE_TensorHandle* h) {
|
||||
Py_INCREF(value);
|
||||
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name},
|
||||
cache.emplace(Key{PyObjectPtr{value}, dtype, ctx, device_name},
|
||||
tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
|
||||
}
|
||||
|
||||
|
@ -73,16 +73,20 @@ struct TFE_TensorHandleCache {
|
||||
~TFE_TensorHandleCache() { DecrefUnrefAll(); }
|
||||
|
||||
TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype,
|
||||
TFE_Context* ctx,
|
||||
absl::string_view device_name) const;
|
||||
|
||||
void Insert(PyObject* value, tensorflow::DataType dtype,
|
||||
void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
|
||||
absl::string_view device_name, TFE_TensorHandle* h);
|
||||
|
||||
void Clear();
|
||||
|
||||
private:
|
||||
// TODO(slebedev): should the key depend on TFE_Context?
|
||||
using Key = std::tuple<PyObjectPtr, tensorflow::DataType, absl::string_view>;
|
||||
// TODO(kkb): Instead of `TFE_Context*` key, ideally Python's context object
|
||||
// should have TFE_TensorHandleCache instance. Migrate once we Python context
|
||||
// object is backed by C++ data structure. b/169790439
|
||||
using Key = std::tuple<PyObjectPtr, tensorflow::DataType, TFE_Context*,
|
||||
absl::string_view>;
|
||||
|
||||
void DecrefUnrefAll() {
|
||||
for (const auto& p : cache) {
|
||||
|
Loading…
Reference in New Issue
Block a user