Place all py_func op on the local host's address space.

PiperOrigin-RevId: 289903686
Change-Id: I38f3b8020cea5b3eab1e5d9141c32350473dadfa
This commit is contained in:
A. Unique TensorFlower 2020-01-15 11:41:45 -08:00 committed by TensorFlower Gardener
parent cd326c6548
commit f18ffa8204
6 changed files with 4 additions and 75 deletions

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
@ -620,16 +619,3 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
}
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
ctx->context->HostCPU()->parsed_name());
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
void* data = tensorflow::port::Malloc(str.length());
str.copy(static_cast<char*>(data), str.length(), 0);
buf->data = data;
buf->length = str.length();
buf->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
}

View File

@ -458,11 +458,6 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
// Retrieves the address space (i.e. job, replia, task) of the local host and
// saves it in the buffer.
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -785,13 +785,6 @@ class Context(object):
"""List of the names of devices available to execute operations."""
return self._devices
def host_address_space(self):
self.ensure_initialized()
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
return address_space
# TODO(fishx): remove this property.
@property
def execution_mode(self):

View File

@ -31,7 +31,6 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -561,7 +560,7 @@ class EagerPyFuncTest(PyFuncTestBase):
with ops.device("/job:worker/task:0/cpu:0"):
a = array_ops.ones((3, 3), dtype=dtypes.float32)
x = array_ops.ones((3, 1), dtype=dtypes.float32)
output = math_ops.matmul(a, x)
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
ret = session.run(output)
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
@ -740,39 +739,6 @@ class EagerPyFuncTest(PyFuncTestBase):
self.assertEqual(y, 1.0)
self.assertEqual(dy_dx, 2.0)
def testEagerPyFuncPlacement(self):
def f(x):
return math_ops.square(x)
def get_device(tensor):
if isinstance(tensor, ops.EagerTensor):
return tensor.device
else:
return tensor.op.device
const_op = constant_op.constant(3.0, dtype=dtypes.float32)
# PyFuncOp should be placed on the localhost's address space.
py_func_op = script_ops.eager_py_func(
func=f, inp=[const_op], Tout=dtypes.float32)
self.assertRegexpMatches(
get_device(py_func_op), "/job:localhost/replica:0/task:0")
self.assertEqual(self.evaluate(py_func_op), 9.0)
# Only run the remaining test if there exists GPU device.
if not config.list_physical_devices("GPU"):
return
with test_util.device(use_gpu=True):
py_func_op = script_ops.eager_py_func(
func=f, inp=[const_op], Tout=dtypes.float32)
# PyFuncOp should be placed on the GPU device within localhost's address
# space.
self.assertEqual(
get_device(py_func_op),
"/job:localhost/replica:0/task:0/device:GPU:0")
self.assertEqual(self.evaluate(py_func_op), 9.0)
@test_util.run_v1_only("b/120545219")
def testEagerRespectsDevicePlacmentOfOp(self):

View File

@ -449,9 +449,7 @@ def eager_py_func(func, inp, Tout, name=None):
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
if `func` returns None.
"""
with ops.device(context.context().host_address_space()):
return _internal_py_func(
func=func, inp=inp, Tout=Tout, eager=True, name=name)
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
def py_func_common(func, inp, Tout, stateful=True, name=None):
@ -520,14 +518,8 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
result, = result
return result
with ops.device(context.context().host_address_space()):
return _internal_py_func(
func=func,
inp=inp,
Tout=Tout,
stateful=stateful,
eager=False,
name=name)
return _internal_py_func(
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
@deprecation.deprecated(

View File

@ -364,9 +364,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
return output;
},
py::return_value_policy::reference);
m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
});
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());