Make TFE_TensorHandleCache aware of TFE_Context.
Some tests reset the global context, and TensorHandleCache was incorrectly shared across different context before. PiperOrigin-RevId: 334870644 Change-Id: I6159ca89d67b3939e98d7b43ce847ac205abf650
This commit is contained in:
parent
0fbd189a50
commit
f16b71ccf0
@ -303,9 +303,7 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
try:
|
try:
|
||||||
config.set_device_policy('silent')
|
config.set_device_policy('silent')
|
||||||
config.set_soft_device_placement(True)
|
config.set_soft_device_placement(True)
|
||||||
# Avoid the TensorHandle cache hit.
|
cpu_tensor = constant_op.constant(1.0)
|
||||||
# TODO(b/169790439): include Context to the TensorHandle cache.
|
|
||||||
cpu_tensor = constant_op.constant(1.1)
|
|
||||||
result = cpu_tensor + cpu_tensor
|
result = cpu_tensor + cpu_tensor
|
||||||
self.assertEqual(result.device,
|
self.assertEqual(result.device,
|
||||||
'/job:localhost/replica:0/task:0/device:GPU:0')
|
'/job:localhost/replica:0/task:0/device:GPU:0')
|
||||||
@ -504,6 +502,18 @@ class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
context.async_clear_error()
|
context.async_clear_error()
|
||||||
config.set_synchronous_execution(True)
|
config.set_synchronous_execution(True)
|
||||||
|
|
||||||
|
def testCrossContextTensorCache(self):
|
||||||
|
old_context = context.context()
|
||||||
|
old_x = constant_op.constant(9.5)
|
||||||
|
context._set_context(context.Context())
|
||||||
|
|
||||||
|
try:
|
||||||
|
new_x = constant_op.constant(9.5)
|
||||||
|
self.assertEqual(new_x.numpy(), 9.5)
|
||||||
|
finally:
|
||||||
|
context._set_context(old_context)
|
||||||
|
|
||||||
|
self.assertEqual(old_x.numpy(), 9.5)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -335,12 +335,12 @@ TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
|
|||||||
// TODO(slebedev): also cache singleton NumPy arrays and scalars?
|
// TODO(slebedev): also cache singleton NumPy arrays and scalars?
|
||||||
if (PyArray_IsPythonNumber(value)) {
|
if (PyArray_IsPythonNumber(value)) {
|
||||||
auto* cache = TFE_TensorHandleCache::Get();
|
auto* cache = TFE_TensorHandleCache::Get();
|
||||||
TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name);
|
TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
|
||||||
if (handle != nullptr) return handle;
|
if (handle != nullptr) return handle;
|
||||||
handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
|
handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
|
||||||
if (handle == nullptr) return nullptr;
|
if (handle == nullptr) return nullptr;
|
||||||
if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
|
if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
|
||||||
cache->Insert(value, dtype, device_name, handle);
|
cache->Insert(value, dtype, ctx, device_name, handle);
|
||||||
}
|
}
|
||||||
return handle;
|
return handle;
|
||||||
} else {
|
} else {
|
||||||
|
@ -38,10 +38,10 @@ TFE_TensorHandleCache* TFE_TensorHandleCache::Get() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
||||||
PyObject* value, tensorflow::DataType dtype,
|
PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
|
||||||
absl::string_view device_name) const {
|
absl::string_view device_name) const {
|
||||||
CHECK_NOTNULL(value);
|
CHECK_NOTNULL(value);
|
||||||
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, device_name});
|
const auto it = cache.find(Key{PyObjectPtr{value}, dtype, ctx, device_name});
|
||||||
if (it == cache.end()) {
|
if (it == cache.end()) {
|
||||||
scalar_cache_misses->GetCell()->IncrementBy(1);
|
scalar_cache_misses->GetCell()->IncrementBy(1);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -53,10 +53,11 @@ TFE_TensorHandle* TFE_TensorHandleCache::Lookup(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
|
void TFE_TensorHandleCache::Insert(PyObject* value, tensorflow::DataType dtype,
|
||||||
|
TFE_Context* ctx,
|
||||||
absl::string_view device_name,
|
absl::string_view device_name,
|
||||||
TFE_TensorHandle* h) {
|
TFE_TensorHandle* h) {
|
||||||
Py_INCREF(value);
|
Py_INCREF(value);
|
||||||
cache.emplace(Key{PyObjectPtr{value}, dtype, device_name},
|
cache.emplace(Key{PyObjectPtr{value}, dtype, ctx, device_name},
|
||||||
tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
|
tensorflow::wrap(tensorflow::unwrap(h)->Copy()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,16 +73,20 @@ struct TFE_TensorHandleCache {
|
|||||||
~TFE_TensorHandleCache() { DecrefUnrefAll(); }
|
~TFE_TensorHandleCache() { DecrefUnrefAll(); }
|
||||||
|
|
||||||
TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype,
|
TFE_TensorHandle* Lookup(PyObject* value, tensorflow::DataType dtype,
|
||||||
|
TFE_Context* ctx,
|
||||||
absl::string_view device_name) const;
|
absl::string_view device_name) const;
|
||||||
|
|
||||||
void Insert(PyObject* value, tensorflow::DataType dtype,
|
void Insert(PyObject* value, tensorflow::DataType dtype, TFE_Context* ctx,
|
||||||
absl::string_view device_name, TFE_TensorHandle* h);
|
absl::string_view device_name, TFE_TensorHandle* h);
|
||||||
|
|
||||||
void Clear();
|
void Clear();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// TODO(slebedev): should the key depend on TFE_Context?
|
// TODO(kkb): Instead of `TFE_Context*` key, ideally Python's context object
|
||||||
using Key = std::tuple<PyObjectPtr, tensorflow::DataType, absl::string_view>;
|
// should have TFE_TensorHandleCache instance. Migrate once we Python context
|
||||||
|
// object is backed by C++ data structure. b/169790439
|
||||||
|
using Key = std::tuple<PyObjectPtr, tensorflow::DataType, TFE_Context*,
|
||||||
|
absl::string_view>;
|
||||||
|
|
||||||
void DecrefUnrefAll() {
|
void DecrefUnrefAll() {
|
||||||
for (const auto& p : cache) {
|
for (const auto& p : cache) {
|
||||||
|
Loading…
Reference in New Issue
Block a user