Place all py_func op on the local host's address space.
PiperOrigin-RevId: 289903686 Change-Id: I38f3b8020cea5b3eab1e5d9141c32350473dadfa
This commit is contained in:
parent
cd326c6548
commit
f18ffa8204
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_internal.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
@ -620,16 +619,3 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||
return new TFE_Executor(&ctx->context->Executor());
|
||||
}
|
||||
|
||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
||||
ctx->context->HostCPU()->parsed_name());
|
||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
||||
void* data = tensorflow::port::Malloc(str.length());
|
||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = str.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
}
|
||||
|
@ -458,11 +458,6 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status);
|
||||
|
||||
// Retrieves the address space (i.e. job, replia, task) of the local host and
|
||||
// saves it in the buffer.
|
||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
TF_Buffer* buf);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -785,13 +785,6 @@ class Context(object):
|
||||
"""List of the names of devices available to execute operations."""
|
||||
return self._devices
|
||||
|
||||
def host_address_space(self):
|
||||
self.ensure_initialized()
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
|
||||
address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
|
||||
return address_space
|
||||
|
||||
# TODO(fishx): remove this property.
|
||||
@property
|
||||
def execution_mode(self):
|
||||
|
@ -31,7 +31,6 @@ from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -561,7 +560,7 @@ class EagerPyFuncTest(PyFuncTestBase):
|
||||
with ops.device("/job:worker/task:0/cpu:0"):
|
||||
a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||
x = array_ops.ones((3, 1), dtype=dtypes.float32)
|
||||
output = math_ops.matmul(a, x)
|
||||
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
|
||||
ret = session.run(output)
|
||||
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
|
||||
|
||||
@ -740,39 +739,6 @@ class EagerPyFuncTest(PyFuncTestBase):
|
||||
self.assertEqual(y, 1.0)
|
||||
self.assertEqual(dy_dx, 2.0)
|
||||
|
||||
def testEagerPyFuncPlacement(self):
|
||||
|
||||
def f(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
def get_device(tensor):
|
||||
if isinstance(tensor, ops.EagerTensor):
|
||||
return tensor.device
|
||||
else:
|
||||
return tensor.op.device
|
||||
|
||||
const_op = constant_op.constant(3.0, dtype=dtypes.float32)
|
||||
# PyFuncOp should be placed on the localhost's address space.
|
||||
py_func_op = script_ops.eager_py_func(
|
||||
func=f, inp=[const_op], Tout=dtypes.float32)
|
||||
self.assertRegexpMatches(
|
||||
get_device(py_func_op), "/job:localhost/replica:0/task:0")
|
||||
self.assertEqual(self.evaluate(py_func_op), 9.0)
|
||||
|
||||
# Only run the remaining test if there exists GPU device.
|
||||
if not config.list_physical_devices("GPU"):
|
||||
return
|
||||
|
||||
with test_util.device(use_gpu=True):
|
||||
py_func_op = script_ops.eager_py_func(
|
||||
func=f, inp=[const_op], Tout=dtypes.float32)
|
||||
# PyFuncOp should be placed on the GPU device within localhost's address
|
||||
# space.
|
||||
self.assertEqual(
|
||||
get_device(py_func_op),
|
||||
"/job:localhost/replica:0/task:0/device:GPU:0")
|
||||
self.assertEqual(self.evaluate(py_func_op), 9.0)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testEagerRespectsDevicePlacmentOfOp(self):
|
||||
|
||||
|
@ -449,9 +449,7 @@ def eager_py_func(func, inp, Tout, name=None):
|
||||
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
||||
if `func` returns None.
|
||||
"""
|
||||
with ops.device(context.context().host_address_space()):
|
||||
return _internal_py_func(
|
||||
func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||
|
||||
|
||||
def py_func_common(func, inp, Tout, stateful=True, name=None):
|
||||
@ -520,14 +518,8 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
|
||||
result, = result
|
||||
return result
|
||||
|
||||
with ops.device(context.context().host_address_space()):
|
||||
return _internal_py_func(
|
||||
func=func,
|
||||
inp=inp,
|
||||
Tout=Tout,
|
||||
stateful=stateful,
|
||||
eager=False,
|
||||
name=name)
|
||||
return _internal_py_func(
|
||||
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
|
@ -364,9 +364,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
return output;
|
||||
},
|
||||
py::return_value_policy::reference);
|
||||
m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
|
||||
TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
|
||||
});
|
||||
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
|
Loading…
Reference in New Issue
Block a user