Place all py_func op on the local host's address space.
PiperOrigin-RevId: 289903686 Change-Id: I38f3b8020cea5b3eab1e5d9141c32350473dadfa
This commit is contained in:
parent
cd326c6548
commit
f18ffa8204
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
#include "tensorflow/core/lib/monitoring/gauge.h"
|
#include "tensorflow/core/lib/monitoring/gauge.h"
|
||||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||||
|
@ -620,16 +619,3 @@ void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
|
||||||
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
|
||||||
return new TFE_Executor(&ctx->context->Executor());
|
return new TFE_Executor(&ctx->context->Executor());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
|
|
||||||
auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
|
|
||||||
ctx->context->HostCPU()->parsed_name());
|
|
||||||
auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
|
|
||||||
void* data = tensorflow::port::Malloc(str.length());
|
|
||||||
str.copy(static_cast<char*>(data), str.length(), 0);
|
|
||||||
buf->data = data;
|
|
||||||
buf->length = str.length();
|
|
||||||
buf->data_deallocator = [](void* data, size_t length) {
|
|
||||||
tensorflow::port::Free(data);
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
|
@ -458,11 +458,6 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||||
void (*deallocator)(void* data, size_t len, void* arg),
|
void (*deallocator)(void* data, size_t len, void* arg),
|
||||||
void* deallocator_arg, TF_Status* status);
|
void* deallocator_arg, TF_Status* status);
|
||||||
|
|
||||||
// Retrieves the address space (i.e. job, replia, task) of the local host and
|
|
||||||
// saves it in the buffer.
|
|
||||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
|
||||||
TF_Buffer* buf);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -785,13 +785,6 @@ class Context(object):
|
||||||
"""List of the names of devices available to execute operations."""
|
"""List of the names of devices available to execute operations."""
|
||||||
return self._devices
|
return self._devices
|
||||||
|
|
||||||
def host_address_space(self):
|
|
||||||
self.ensure_initialized()
|
|
||||||
with c_api_util.tf_buffer() as buffer_:
|
|
||||||
pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
|
|
||||||
address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
|
|
||||||
return address_space
|
|
||||||
|
|
||||||
# TODO(fishx): remove this property.
|
# TODO(fishx): remove this property.
|
||||||
@property
|
@property
|
||||||
def execution_mode(self):
|
def execution_mode(self):
|
||||||
|
|
|
@ -31,7 +31,6 @@ from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import config
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
@ -561,7 +560,7 @@ class EagerPyFuncTest(PyFuncTestBase):
|
||||||
with ops.device("/job:worker/task:0/cpu:0"):
|
with ops.device("/job:worker/task:0/cpu:0"):
|
||||||
a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
a = array_ops.ones((3, 3), dtype=dtypes.float32)
|
||||||
x = array_ops.ones((3, 1), dtype=dtypes.float32)
|
x = array_ops.ones((3, 1), dtype=dtypes.float32)
|
||||||
output = math_ops.matmul(a, x)
|
output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
|
||||||
ret = session.run(output)
|
ret = session.run(output)
|
||||||
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
|
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
|
||||||
|
|
||||||
|
@ -740,39 +739,6 @@ class EagerPyFuncTest(PyFuncTestBase):
|
||||||
self.assertEqual(y, 1.0)
|
self.assertEqual(y, 1.0)
|
||||||
self.assertEqual(dy_dx, 2.0)
|
self.assertEqual(dy_dx, 2.0)
|
||||||
|
|
||||||
def testEagerPyFuncPlacement(self):
|
|
||||||
|
|
||||||
def f(x):
|
|
||||||
return math_ops.square(x)
|
|
||||||
|
|
||||||
def get_device(tensor):
|
|
||||||
if isinstance(tensor, ops.EagerTensor):
|
|
||||||
return tensor.device
|
|
||||||
else:
|
|
||||||
return tensor.op.device
|
|
||||||
|
|
||||||
const_op = constant_op.constant(3.0, dtype=dtypes.float32)
|
|
||||||
# PyFuncOp should be placed on the localhost's address space.
|
|
||||||
py_func_op = script_ops.eager_py_func(
|
|
||||||
func=f, inp=[const_op], Tout=dtypes.float32)
|
|
||||||
self.assertRegexpMatches(
|
|
||||||
get_device(py_func_op), "/job:localhost/replica:0/task:0")
|
|
||||||
self.assertEqual(self.evaluate(py_func_op), 9.0)
|
|
||||||
|
|
||||||
# Only run the remaining test if there exists GPU device.
|
|
||||||
if not config.list_physical_devices("GPU"):
|
|
||||||
return
|
|
||||||
|
|
||||||
with test_util.device(use_gpu=True):
|
|
||||||
py_func_op = script_ops.eager_py_func(
|
|
||||||
func=f, inp=[const_op], Tout=dtypes.float32)
|
|
||||||
# PyFuncOp should be placed on the GPU device within localhost's address
|
|
||||||
# space.
|
|
||||||
self.assertEqual(
|
|
||||||
get_device(py_func_op),
|
|
||||||
"/job:localhost/replica:0/task:0/device:GPU:0")
|
|
||||||
self.assertEqual(self.evaluate(py_func_op), 9.0)
|
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testEagerRespectsDevicePlacmentOfOp(self):
|
def testEagerRespectsDevicePlacmentOfOp(self):
|
||||||
|
|
||||||
|
|
|
@ -449,9 +449,7 @@ def eager_py_func(func, inp, Tout, name=None):
|
||||||
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
|
||||||
if `func` returns None.
|
if `func` returns None.
|
||||||
"""
|
"""
|
||||||
with ops.device(context.context().host_address_space()):
|
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
||||||
return _internal_py_func(
|
|
||||||
func=func, inp=inp, Tout=Tout, eager=True, name=name)
|
|
||||||
|
|
||||||
|
|
||||||
def py_func_common(func, inp, Tout, stateful=True, name=None):
|
def py_func_common(func, inp, Tout, stateful=True, name=None):
|
||||||
|
@ -520,14 +518,8 @@ def py_func_common(func, inp, Tout, stateful=True, name=None):
|
||||||
result, = result
|
result, = result
|
||||||
return result
|
return result
|
||||||
|
|
||||||
with ops.device(context.context().host_address_space()):
|
return _internal_py_func(
|
||||||
return _internal_py_func(
|
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
|
||||||
func=func,
|
|
||||||
inp=inp,
|
|
||||||
Tout=Tout,
|
|
||||||
stateful=stateful,
|
|
||||||
eager=False,
|
|
||||||
name=name)
|
|
||||||
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
|
|
|
@ -364,9 +364,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||||
return output;
|
return output;
|
||||||
},
|
},
|
||||||
py::return_value_policy::reference);
|
py::return_value_policy::reference);
|
||||||
m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
|
|
||||||
TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
|
|
||||||
});
|
|
||||||
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
|
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
|
||||||
tensorflow::Safe_TF_StatusPtr status =
|
tensorflow::Safe_TF_StatusPtr status =
|
||||||
tensorflow::make_safe(TF_NewStatus());
|
tensorflow::make_safe(TF_NewStatus());
|
||||||
|
|
Loading…
Reference in New Issue