Use TF_Status* SWIG typemap
This avoids the need to use raise_exception_on_not_ok_status which has been removed from the v2 API. PiperOrigin-RevId: 244049607
This commit is contained in:
parent
4509bdfd5c
commit
3afb32ce1f
@ -29,8 +29,7 @@ namespace checkpoint {
|
||||
|
||||
class TensorSliceReader;
|
||||
|
||||
CheckpointReader::CheckpointReader(const string& filename,
|
||||
TF_Status* out_status)
|
||||
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
|
||||
: reader_(nullptr),
|
||||
v2_reader_(nullptr),
|
||||
var_to_shape_map_(nullptr),
|
||||
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
|
||||
v2_reader_.reset(
|
||||
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
|
||||
if (!v2_reader_->status().ok()) {
|
||||
Set_TF_Status_from_Status(out_status, v2_reader_->status());
|
||||
Set_TF_Status_from_Status(status, v2_reader_->status());
|
||||
return;
|
||||
}
|
||||
auto result = BuildV2VarMaps();
|
||||
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
|
||||
} else {
|
||||
reader_.reset(new TensorSliceReader(filename));
|
||||
if (!reader_->status().ok()) {
|
||||
Set_TF_Status_from_Status(out_status, reader_->status());
|
||||
Set_TF_Status_from_Status(status, reader_->status());
|
||||
return;
|
||||
}
|
||||
var_to_shape_map_.reset(
|
||||
|
@ -39,7 +39,7 @@ class TensorSliceReader;
|
||||
// variables.
|
||||
class CheckpointReader {
|
||||
public:
|
||||
CheckpointReader(const string& filepattern, TF_Status* out_status);
|
||||
CheckpointReader(const string& filename, TF_Status* status);
|
||||
|
||||
bool HasTensor(const string& name) const;
|
||||
const string DebugString() const;
|
||||
|
@ -44,15 +44,14 @@ namespace tensorflow {
|
||||
namespace swig {
|
||||
|
||||
static std::vector<string> ListDevicesWithSessionConfig(
|
||||
const tensorflow::ConfigProto& config, TF_Status* out_status) {
|
||||
const tensorflow::ConfigProto& config, TF_Status* status) {
|
||||
std::vector<string> output;
|
||||
SessionOptions options;
|
||||
options.config = config;
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
Status status = DeviceFactory::AddDevices(
|
||||
options, "" /* name_prefix */, &devices);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
Status s = DeviceFactory::AddDevices(options, "" /* name_prefix */, &devices);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<Device>& device : devices) {
|
||||
@ -60,7 +59,7 @@ static std::vector<string> ListDevicesWithSessionConfig(
|
||||
string attr_serialized;
|
||||
if (!attr.SerializeToString(&attr_serialized)) {
|
||||
Set_TF_Status_from_Status(
|
||||
out_status,
|
||||
status,
|
||||
errors::Internal("Could not serialize device string"));
|
||||
output.clear();
|
||||
return output;
|
||||
@ -71,9 +70,9 @@ static std::vector<string> ListDevicesWithSessionConfig(
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<string> ListDevices(TF_Status* out_status) {
|
||||
std::vector<string> ListDevices(TF_Status* status) {
|
||||
tensorflow::ConfigProto session_config;
|
||||
return ListDevicesWithSessionConfig(session_config, out_status);
|
||||
return ListDevicesWithSessionConfig(session_config, status);
|
||||
}
|
||||
|
||||
} // namespace swig
|
||||
@ -91,9 +90,9 @@ std::vector<string> ListDevices(TF_Status* out_status) {
|
||||
// Wrap this function
|
||||
namespace tensorflow {
|
||||
namespace swig {
|
||||
std::vector<string> ListDevices(TF_Status* out_status);
|
||||
std::vector<string> ListDevices(TF_Status* status);
|
||||
static std::vector<string> ListDevicesWithSessionConfig(
|
||||
const tensorflow::ConfigProto& config, TF_Status* out_status);
|
||||
const tensorflow::ConfigProto& config, TF_Status* status);
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -101,12 +100,10 @@ static std::vector<string> ListDevicesWithSessionConfig(
|
||||
def list_devices(session_config=None):
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if session_config:
|
||||
return ListDevicesWithSessionConfig(session_config.SerializeToString(),
|
||||
status)
|
||||
return ListDevicesWithSessionConfig(session_config.SerializeToString())
|
||||
else:
|
||||
return ListDevices(status)
|
||||
return ListDevices()
|
||||
%}
|
||||
|
||||
%unignoreall
|
||||
|
@ -1425,9 +1425,8 @@ class BaseSession(SessionInterface):
|
||||
options_ptr = tf_session.TF_NewBufferFromString(
|
||||
compat.as_bytes(callable_options.SerializeToString()))
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._handle = tf_session.TF_SessionMakeCallable(
|
||||
session._session, options_ptr, status)
|
||||
session._session, options_ptr)
|
||||
finally:
|
||||
tf_session.TF_DeleteBuffer(options_ptr)
|
||||
|
||||
@ -1437,10 +1436,8 @@ class BaseSession(SessionInterface):
|
||||
run_metadata = kwargs.get('run_metadata', None)
|
||||
try:
|
||||
run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
|
||||
# TODO(mrry): Switch to raising an exception from the SWIG wrapper.
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
ret = tf_session.TF_SessionRunCallable(
|
||||
self._session._session, self._handle, args, status,
|
||||
ret = tf_session.TF_SessionRunCallable(self._session._session,
|
||||
self._handle, args,
|
||||
run_metadata_ptr)
|
||||
if run_metadata:
|
||||
proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
|
||||
@ -1455,9 +1452,8 @@ class BaseSession(SessionInterface):
|
||||
# called before this destructor, in which case `self._session._session`
|
||||
# will be `None`.
|
||||
if self._handle is not None and self._session._session is not None:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
tf_session.TF_SessionReleaseCallable(
|
||||
self._session._session, self._handle, status)
|
||||
self._session._session, self._handle)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# TODO(b/74355905): Reimplement `Session.make_callable()` using this method
|
||||
|
@ -571,8 +571,7 @@ def TF_Reset(target, containers=None, config=None):
|
||||
from tensorflow.python.framework import errors
|
||||
opts = TF_NewSessionOptions(target=target, config=config)
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
TF_Reset_wrapper(opts, containers, status)
|
||||
TF_Reset_wrapper(opts, containers)
|
||||
finally:
|
||||
TF_DeleteSessionOptions(opts)
|
||||
%}
|
||||
|
@ -194,16 +194,13 @@ void MakeCallableHelper(tensorflow::Session* session,
|
||||
|
||||
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
|
||||
const TF_Buffer* callable_options,
|
||||
int64_t* out_handle,
|
||||
TF_Status* out_status) {
|
||||
MakeCallableHelper(session->session, callable_options, out_handle,
|
||||
out_status);
|
||||
int64_t* out_handle, TF_Status* status) {
|
||||
MakeCallableHelper(session->session, callable_options, out_handle, status);
|
||||
}
|
||||
void TF_SessionMakeCallable(TF_Session* session,
|
||||
const TF_Buffer* callable_options,
|
||||
int64_t* out_handle, TF_Status* out_status) {
|
||||
MakeCallableHelper(session->session, callable_options, out_handle,
|
||||
out_status);
|
||||
int64_t* out_handle, TF_Status* status) {
|
||||
MakeCallableHelper(session->session, callable_options, out_handle, status);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -291,32 +288,28 @@ void RunCallableHelper(tensorflow::Session* session, int64_t handle,
|
||||
|
||||
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
|
||||
int64_t handle, PyObject* feed_values,
|
||||
TF_Status* out_status,
|
||||
PyObjectVector* out_values,
|
||||
TF_Buffer* run_metadata) {
|
||||
RunCallableHelper(session->session, handle, feed_values, out_status,
|
||||
out_values, run_metadata);
|
||||
TF_Buffer* run_metadata,
|
||||
TF_Status* status) {
|
||||
RunCallableHelper(session->session, handle, feed_values, status, out_values,
|
||||
run_metadata);
|
||||
ClearDecrefCache();
|
||||
}
|
||||
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
|
||||
PyObject* feed_values, TF_Status* out_status,
|
||||
PyObjectVector* out_values,
|
||||
TF_Buffer* run_metadata) {
|
||||
RunCallableHelper(session->session, handle, feed_values, out_status,
|
||||
out_values, run_metadata);
|
||||
PyObject* feed_values, PyObjectVector* out_values,
|
||||
TF_Buffer* run_metadata, TF_Status* status) {
|
||||
RunCallableHelper(session->session, handle, feed_values, status, out_values,
|
||||
run_metadata);
|
||||
ClearDecrefCache();
|
||||
}
|
||||
|
||||
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
|
||||
int64_t handle,
|
||||
TF_Status* out_status) {
|
||||
Set_TF_Status_from_Status(out_status,
|
||||
session->session->ReleaseCallable(handle));
|
||||
int64_t handle, TF_Status* status) {
|
||||
Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
|
||||
}
|
||||
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
|
||||
TF_Status* out_status) {
|
||||
Set_TF_Status_from_Status(out_status,
|
||||
session->session->ReleaseCallable(handle));
|
||||
TF_Status* status) {
|
||||
Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
|
||||
}
|
||||
|
||||
// Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
|
||||
@ -348,9 +341,9 @@ void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
|
||||
|
||||
// Wrapper for TF_Reset that converts the string vectors to character arrays.
|
||||
void TF_Reset_wrapper(const TF_SessionOptions* opt,
|
||||
const NameVector& containers, TF_Status* out_status) {
|
||||
const NameVector& containers, TF_Status* status) {
|
||||
TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
|
||||
out_status);
|
||||
status);
|
||||
}
|
||||
|
||||
void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,
|
||||
|
@ -65,27 +65,26 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
|
||||
// Python wrappers for the `Session::MakeCallable()` API.
|
||||
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
|
||||
const TF_Buffer* callable_options,
|
||||
int64_t* out_handle,
|
||||
TF_Status* out_status);
|
||||
int64_t* out_handle, TF_Status* status);
|
||||
void TF_SessionMakeCallable(TF_Session* session,
|
||||
const TF_Buffer* callable_options,
|
||||
int64_t* out_handle, TF_Status* out_status);
|
||||
int64_t* out_handle, TF_Status* status);
|
||||
|
||||
// Python wrappers for the `Session::RunCallable()` API.
|
||||
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
|
||||
int64_t handle, PyObject* feed_values,
|
||||
TF_Status* out_status,
|
||||
PyObjectVector* out_values,
|
||||
TF_Buffer* run_metadata);
|
||||
TF_Buffer* run_metadata,
|
||||
TF_Status* status);
|
||||
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
|
||||
PyObject* feed_values, TF_Status* out_status,
|
||||
PyObjectVector* out_values, TF_Buffer* run_metadata);
|
||||
PyObject* feed_values, PyObjectVector* out_values,
|
||||
TF_Buffer* run_metadata, TF_Status* status);
|
||||
|
||||
// Python wrappers for the `Session::ReleaseCallable()` API.
|
||||
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
|
||||
int64_t handle, TF_Status* out_status);
|
||||
int64_t handle, TF_Status* status);
|
||||
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
|
||||
// Set up the graph with the intended feeds and fetches for partial run.
|
||||
// *out_handle is owned by the caller.
|
||||
@ -118,7 +117,7 @@ void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
|
||||
|
||||
// Wrapper for TF_Reset that converts the string vectors to character arrays.
|
||||
void TF_Reset_wrapper(const TF_SessionOptions* opt,
|
||||
const NameVector& containers, TF_Status* out_status);
|
||||
const NameVector& containers, TF_Status* status);
|
||||
|
||||
// Convenience wrapper around EqualGraphDef to make it easier to wrap.
|
||||
// Returns an explanation if a difference is found, or the empty string
|
||||
|
@ -45,7 +45,6 @@ def start_tracing(service_addr,
|
||||
Raises:
|
||||
UnavailableError: If no trace event is collected.
|
||||
"""
|
||||
# TODO(fishx): Uses errors.raise_exception_on_not_ok_status instead.
|
||||
if not pywrap_tensorflow.TFE_ProfilerClientStartTracing(
|
||||
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
||||
num_tracing_attempts):
|
||||
|
@ -692,10 +692,9 @@ def _call_cpp_shape_fn_impl(
|
||||
|
||||
missing_shape_fn = False
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
output = pywrap_tensorflow.RunCppShapeInference(
|
||||
graph_def_version, node_def_str, input_shapes, input_tensors,
|
||||
input_tensors_as_shapes, status)
|
||||
input_tensors_as_shapes)
|
||||
except errors.InvalidArgumentError as err:
|
||||
if err.message.startswith("No shape inference function exists for op"):
|
||||
missing_shape_fn = True
|
||||
|
@ -175,9 +175,9 @@ std::vector<string> RunCppShapeInference(
|
||||
const std::vector<string>& input_serialized_shapes,
|
||||
PyObject* input_constant_tensor_values,
|
||||
const std::vector<string>& input_constant_tensor_as_shape_values,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
if (!PyList_Check(input_constant_tensor_values)) {
|
||||
TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Invalid python value");
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "Invalid python value");
|
||||
return std::vector<string>();
|
||||
}
|
||||
|
||||
@ -191,13 +191,13 @@ std::vector<string> RunCppShapeInference(
|
||||
|
||||
std::vector<string> output;
|
||||
string input_tensors_needed_out;
|
||||
tensorflow::Status status = RunCppShapeInferenceImpl(
|
||||
tensorflow::Status s = RunCppShapeInferenceImpl(
|
||||
graph_def_version, serialized_node_def, input_serialized_shapes,
|
||||
input_constant_tensor_values_v, input_constant_tensor_as_shape_values,
|
||||
&output, &input_tensors_needed_out);
|
||||
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
if (!s.ok()) {
|
||||
return std::vector<string>();
|
||||
}
|
||||
output.push_back(input_tensors_needed_out);
|
||||
|
@ -46,7 +46,7 @@ std::vector<string> RunCppShapeInference(
|
||||
const std::vector<string>& input_serialized_shapes,
|
||||
PyObject* input_constant_tensor_values,
|
||||
const std::vector<string>& input_constant_tensor_as_shape_values,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
|
||||
} // namespace swig
|
||||
} // namespace tensorflow
|
||||
|
@ -112,9 +112,7 @@ class ErrorsTest(test.TestCase):
|
||||
|
||||
def testStatusDoesNotLeak(self):
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.DeleteFile(
|
||||
compat.as_bytes("/DOES_NOT_EXIST/"), status)
|
||||
pywrap_tensorflow.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
|
||||
except:
|
||||
pass
|
||||
gc.collect()
|
||||
|
@ -131,7 +131,7 @@ struct GCluster {
|
||||
|
||||
|
||||
static GCluster TF_NewCluster(bool allow_soft_placement,
|
||||
bool disable_detailed_stats, TF_Status* out_status) {
|
||||
bool disable_detailed_stats, TF_Status* status) {
|
||||
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
|
||||
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
|
||||
int timeout_s = 60 * 10;
|
||||
@ -141,14 +141,13 @@ static GCluster TF_NewCluster(bool allow_soft_placement,
|
||||
cluster_->DisableDetailedStats(disable_detailed_stats);
|
||||
cluster_->AllowSoftPlacement(allow_soft_placement);
|
||||
cluster_->SetNumWarmupSteps(10);
|
||||
tensorflow::Status status = cluster_->Provision();
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Status s = cluster_->Provision();
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return GCluster(cluster_);
|
||||
}
|
||||
|
||||
static GCluster TF_NewVirtualCluster(
|
||||
const std::vector<tensorflow::NamedDevice>& named_devices,
|
||||
TF_Status* out_status) {
|
||||
const std::vector<tensorflow::NamedDevice>& named_devices, TF_Status* status) {
|
||||
std::unordered_map<string, tensorflow::DeviceProperties> devices;
|
||||
for (const auto& named_device : named_devices) {
|
||||
devices[named_device.name()]= named_device.properties();
|
||||
@ -156,9 +155,9 @@ static GCluster TF_NewVirtualCluster(
|
||||
tensorflow::grappler::Cluster* cluster_ =
|
||||
new tensorflow::grappler::VirtualCluster(devices);
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
tensorflow::Status status = cluster_->Provision();
|
||||
tensorflow::Status s = cluster_->Provision();
|
||||
PyGILState_Release(gstate);
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return GCluster(cluster_);
|
||||
}
|
||||
|
||||
@ -316,7 +315,7 @@ static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
|
||||
static PyObject* TF_MeasureCosts(
|
||||
GItem item,
|
||||
GCluster cluster,
|
||||
bool generate_timeline, TF_Status* out_status) {
|
||||
bool generate_timeline, TF_Status* status) {
|
||||
tensorflow::OpPerformanceList op_performance_data;
|
||||
tensorflow::StepStats step_stats;
|
||||
|
||||
@ -324,25 +323,25 @@ static PyObject* TF_MeasureCosts(
|
||||
tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
|
||||
|
||||
tensorflow::grappler::Costs costs;
|
||||
tensorflow::Status status = _GetOpPerformanceDataAndRunTime(
|
||||
tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
|
||||
*item, &cost_measure, &op_performance_data, &costs);
|
||||
double run_time = FLT_MAX;
|
||||
if (status.ok()) {
|
||||
if (s.ok()) {
|
||||
run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
|
||||
}
|
||||
if (generate_timeline) {
|
||||
tensorflow::RunMetadata metadata;
|
||||
tensorflow::Status s = cluster->Run(
|
||||
tensorflow::Status run_status = cluster->Run(
|
||||
item->graph, item->feed, item->fetch, &metadata);
|
||||
if (s.ok()) {
|
||||
if (run_status.ok()) {
|
||||
step_stats = metadata.step_stats();
|
||||
} else {
|
||||
status = s;
|
||||
s = run_status;
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
if (!status.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
if (!s.ok()) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
@ -370,9 +369,9 @@ static PyObject* TF_MeasureCosts(
|
||||
Py_XDECREF(op_perf_objs);
|
||||
Py_XDECREF(run_time_obj);
|
||||
Py_XDECREF(metadata_obj);
|
||||
status = tensorflow::Status(tensorflow::error::Code::INTERNAL,
|
||||
s = tensorflow::Status(tensorflow::error::Code::INTERNAL,
|
||||
"Error setting return tuples.");
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
Py_INCREF(Py_None);
|
||||
ret = Py_None;
|
||||
}
|
||||
@ -384,23 +383,23 @@ static PyObject* TF_MeasureCosts(
|
||||
static PyObject* TF_DeterminePeakMemoryUsage(
|
||||
GItem item,
|
||||
GCluster cluster,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
if (item.is_none() || cluster.is_none()) {
|
||||
tensorflow::Status status(tensorflow::error::Code::INTERNAL,
|
||||
tensorflow::Status s(tensorflow::error::Code::INTERNAL,
|
||||
"You need both a cluster and an item to determine peak memory usage");
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
tensorflow::grappler::GraphMemory memory(*item);
|
||||
|
||||
tensorflow::Status status;
|
||||
tensorflow::Status s;
|
||||
if (cluster->DetailedStatsEnabled()) {
|
||||
status = memory.InferDynamically(cluster.get());
|
||||
s = memory.InferDynamically(cluster.get());
|
||||
} else {
|
||||
status = memory.InferStatically(cluster->GetDevices());
|
||||
s = memory.InferStatically(cluster->GetDevices());
|
||||
}
|
||||
if (!status.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
@ -434,10 +433,10 @@ static PyObject* TF_DeterminePeakMemoryUsage(
|
||||
|
||||
// Wrap these functions.
|
||||
static GCluster TF_NewCluster(
|
||||
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status);
|
||||
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* status);
|
||||
static GCluster TF_NewVirtualCluster(
|
||||
const std::vector<tensorflow::NamedDevice>& named_devices,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
static void TF_ShutdownCluster(GCluster cluster);
|
||||
static PyObject* TF_ListDevices(GCluster cluster);
|
||||
static PyObject* TF_ListAvailableOps();
|
||||
@ -445,7 +444,7 @@ static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
|
||||
static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
|
||||
static PyObject* TF_MeasureCosts(
|
||||
GItem item, GCluster cluster,
|
||||
bool generate_timeline, TF_Status* out_status);
|
||||
bool generate_timeline, TF_Status* status);
|
||||
static PyObject* TF_DeterminePeakMemoryUsage(
|
||||
GItem item, GCluster cluster,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
|
@ -24,7 +24,6 @@ from tensorflow.core.framework import step_stats_pb2
|
||||
from tensorflow.core.grappler.costs import op_performance_data_pb2
|
||||
from tensorflow.core.protobuf import device_properties_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_cluster
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
class Cluster(object):
|
||||
@ -49,14 +48,13 @@ class Cluster(object):
|
||||
"""
|
||||
self._tf_cluster = None
|
||||
self._generate_timeline = not disable_timeline
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
|
||||
if devices is None:
|
||||
self._tf_cluster = tf_cluster.TF_NewCluster(
|
||||
allow_soft_placement, disable_detailed_stats, status)
|
||||
self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
|
||||
disable_detailed_stats)
|
||||
else:
|
||||
devices_serialized = [device.SerializeToString() for device in devices]
|
||||
self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
|
||||
devices_serialized, status)
|
||||
self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)
|
||||
|
||||
def Shutdown(self):
|
||||
if self._tf_cluster is not None:
|
||||
@ -94,9 +92,8 @@ class Cluster(object):
|
||||
item: The item for which to measure the costs.
|
||||
Returns: The triplet op_perfs, runtime, step_stats.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
ret_from_swig = tf_cluster.TF_MeasureCosts(
|
||||
item.tf_item, self._tf_cluster, self._generate_timeline, status)
|
||||
ret_from_swig = tf_cluster.TF_MeasureCosts(item.tf_item, self._tf_cluster,
|
||||
self._generate_timeline)
|
||||
|
||||
if ret_from_swig is None:
|
||||
return None
|
||||
@ -114,9 +111,8 @@ class Cluster(object):
|
||||
item: The item for which to measure the costs.
|
||||
Returns: A hashtable indexed by device name.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return tf_cluster.TF_DeterminePeakMemoryUsage(
|
||||
item.tf_item, self._tf_cluster, status)
|
||||
return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item,
|
||||
self._tf_cluster)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.grappler import cluster as gcluster
|
||||
from tensorflow.python.grappler import item as gitem
|
||||
|
||||
@ -45,7 +44,6 @@ def GenerateCostReport(metagraph,
|
||||
if cluster is None:
|
||||
cluster = gcluster.Cluster(disable_detailed_stats=False)
|
||||
|
||||
with errors.raise_exception_on_not_ok_status():
|
||||
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
|
||||
per_node_report, verbose,
|
||||
cluster.tf_cluster)
|
||||
|
@ -72,10 +72,10 @@ struct GItem {
|
||||
|
||||
static GItem TF_NewItem(
|
||||
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
|
||||
bool ignore_user_placement, TF_Status* out_status) {
|
||||
bool ignore_user_placement, TF_Status* status) {
|
||||
if (meta_graph.collection_def().count("train_op") == 0) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
out_status,
|
||||
status,
|
||||
tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
|
||||
return nullptr;
|
||||
}
|
||||
@ -87,11 +87,11 @@ static GItem TF_NewItem(
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
|
||||
if (!item) {
|
||||
tensorflow::Set_TF_Status_from_Status(
|
||||
out_status,
|
||||
status,
|
||||
tensorflow::errors::InvalidArgument("Invalid metagraph"));
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
|
||||
tensorflow::Set_TF_Status_from_Status(status, tensorflow::Status::OK());
|
||||
return GItem(item.release());
|
||||
}
|
||||
|
||||
@ -308,7 +308,7 @@ static PyObject* TF_GetColocationGroups(GItem item) {
|
||||
// Wrap these functions.
|
||||
static GItem TF_NewItem(
|
||||
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
|
||||
bool ignore_user_placement, TF_Status* out_status);
|
||||
bool ignore_user_placement, TF_Status* status);
|
||||
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
|
||||
TF_Status* status);
|
||||
static PyObject* TF_GetOpProperties(GItem item);
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
from tensorflow.core.grappler.costs import op_performance_data_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_item
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
class Item(object):
|
||||
@ -87,7 +86,6 @@ class Item(object):
|
||||
return self._tf_item
|
||||
|
||||
def _BuildTFItem(self):
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._tf_item = tf_item.TF_NewItem(self._metagraph.SerializeToString(),
|
||||
self._ignore_colocation,
|
||||
self._ignore_user_placement, status)
|
||||
self._ignore_user_placement)
|
||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow as tf_wrap
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
||||
@ -33,7 +32,6 @@ def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
|
||||
Returns:
|
||||
A string containing the report.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status():
|
||||
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
|
||||
assume_valid_feeds, debug)
|
||||
|
||||
|
@ -95,7 +95,7 @@ PyObject* TF_OptimizeGraph(
|
||||
GCluster cluster,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
const tensorflow::MetaGraphDef& metagraph,
|
||||
bool verbose, const string& graph_id, TF_Status* out_status) {
|
||||
bool verbose, const string& graph_id, TF_Status* status) {
|
||||
tensorflow::grappler::ItemConfig item_config;
|
||||
item_config.apply_optimizations = false;
|
||||
item_config.ignore_user_placement = false;
|
||||
@ -103,18 +103,18 @@ PyObject* TF_OptimizeGraph(
|
||||
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
|
||||
|
||||
if (!grappler_item) {
|
||||
TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensorflow::DeviceBase* cpu_device = nullptr;
|
||||
tensorflow::GraphDef out_graph;
|
||||
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
|
||||
tensorflow::Status status = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
|
||||
tensorflow::Status s = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
|
||||
if (verbose) {
|
||||
optimizer.PrintResult();
|
||||
}
|
||||
tensorflow::Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
string out_graph_str = out_graph.SerializeAsString();
|
||||
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
|
||||
out_graph_str.size());
|
||||
@ -128,7 +128,7 @@ PyObject* TF_OptimizeGraph(
|
||||
GCluster cluster,
|
||||
const tensorflow::ConfigProto& config_proto,
|
||||
const tensorflow::MetaGraphDef& metagraph, bool verbose,
|
||||
const string& graph_id, TF_Status* out_status);
|
||||
const string& graph_id, TF_Status* status);
|
||||
|
||||
|
||||
|
||||
|
@ -21,7 +21,6 @@ from __future__ import print_function
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import pywrap_tensorflow as tf_opt
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.grappler import cluster as gcluster
|
||||
|
||||
|
||||
@ -34,13 +33,12 @@ def OptimizeGraph(config_proto,
|
||||
if not isinstance(config_proto, config_pb2.ConfigProto):
|
||||
raise TypeError('Expected config_proto to be a ConfigProto, saw type %s' %
|
||||
type(config_proto))
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if cluster is None:
|
||||
cluster = gcluster.Cluster()
|
||||
ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
|
||||
config_proto.SerializeToString(),
|
||||
metagraph.SerializeToString(),
|
||||
verbose, graph_id, status)
|
||||
verbose, graph_id)
|
||||
if ret_from_swig is None:
|
||||
return None
|
||||
out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)
|
||||
|
@ -32,125 +32,121 @@ limitations under the License.
|
||||
%}
|
||||
|
||||
%{
|
||||
inline void FileExists(const string& filename, TF_Status* out_status) {
|
||||
tensorflow::Status status = tensorflow::Env::Default()->FileExists(filename);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
inline void FileExists(const string& filename, TF_Status* status) {
|
||||
tensorflow::Status s = tensorflow::Env::Default()->FileExists(filename);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
inline void FileExists(const tensorflow::StringPiece& filename,
|
||||
TF_Status* out_status) {
|
||||
tensorflow::Status status =
|
||||
TF_Status* status) {
|
||||
tensorflow::Status s =
|
||||
tensorflow::Env::Default()->FileExists(string(filename));
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
inline void DeleteFile(const string& filename, TF_Status* out_status) {
|
||||
tensorflow::Status status = tensorflow::Env::Default()->DeleteFile(filename);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
inline void DeleteFile(const string& filename, TF_Status* status) {
|
||||
tensorflow::Status s = tensorflow::Env::Default()->DeleteFile(filename);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
string ReadFileToString(const string& filename, TF_Status* out_status) {
|
||||
string ReadFileToString(const string& filename, TF_Status* status) {
|
||||
string file_content;
|
||||
tensorflow::Status status = ReadFileToString(tensorflow::Env::Default(),
|
||||
tensorflow::Status s = ReadFileToString(tensorflow::Env::Default(),
|
||||
filename, &file_content);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
return file_content;
|
||||
}
|
||||
|
||||
void WriteStringToFile(const string& filename, const string& file_content,
|
||||
TF_Status* out_status) {
|
||||
tensorflow::Status status = WriteStringToFile(tensorflow::Env::Default(),
|
||||
TF_Status* status) {
|
||||
tensorflow::Status s = WriteStringToFile(tensorflow::Env::Default(),
|
||||
filename, file_content);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<string> GetChildren(const string& dir, TF_Status* out_status) {
|
||||
std::vector<string> GetChildren(const string& dir, TF_Status* status) {
|
||||
std::vector<string> results;
|
||||
tensorflow::Status status = tensorflow::Env::Default()->GetChildren(
|
||||
tensorflow::Status s = tensorflow::Env::Default()->GetChildren(
|
||||
dir, &results);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
std::vector<string> GetMatchingFiles(const string& filename,
|
||||
TF_Status* out_status) {
|
||||
std::vector<string> GetMatchingFiles(const string& filename, TF_Status* status) {
|
||||
std::vector<string> results;
|
||||
tensorflow::Status status = tensorflow::Env::Default()->GetMatchingPaths(
|
||||
tensorflow::Status s = tensorflow::Env::Default()->GetMatchingPaths(
|
||||
filename, &results);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
void CreateDir(const string& dirname, TF_Status* out_status) {
|
||||
tensorflow::Status status = tensorflow::Env::Default()->CreateDir(dirname);
|
||||
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
void CreateDir(const string& dirname, TF_Status* status) {
|
||||
tensorflow::Status s = tensorflow::Env::Default()->CreateDir(dirname);
|
||||
if (!s.ok() && s.code() != tensorflow::error::ALREADY_EXISTS) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
void RecursivelyCreateDir(const string& dirname, TF_Status* out_status) {
|
||||
tensorflow::Status status = tensorflow::Env::Default()->RecursivelyCreateDir(
|
||||
void RecursivelyCreateDir(const string& dirname, TF_Status* status) {
|
||||
tensorflow::Status s = tensorflow::Env::Default()->RecursivelyCreateDir(
|
||||
dirname);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
void CopyFile(const string& src, const string& target, bool overwrite,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
// If overwrite is false and the target file exists then its an error.
|
||||
if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) {
|
||||
TF_SetStatus(out_status, TF_ALREADY_EXISTS, "file already exists");
|
||||
TF_SetStatus(status, TF_ALREADY_EXISTS, "file already exists");
|
||||
return;
|
||||
}
|
||||
tensorflow::Status status =
|
||||
tensorflow::Env::Default()->CopyFile(src, target);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Status s = tensorflow::Env::Default()->CopyFile(src, target);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
void RenameFile(const string& src, const string& target, bool overwrite,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
// If overwrite is false and the target file exists then its an error.
|
||||
if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) {
|
||||
TF_SetStatus(out_status, TF_ALREADY_EXISTS, "file already exists");
|
||||
TF_SetStatus(status, TF_ALREADY_EXISTS, "file already exists");
|
||||
return;
|
||||
}
|
||||
tensorflow::Status status = tensorflow::Env::Default()->RenameFile(src,
|
||||
target);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Status s = tensorflow::Env::Default()->RenameFile(src, target);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
using tensorflow::int64;
|
||||
|
||||
void DeleteRecursively(const string& dirname, TF_Status* out_status) {
|
||||
void DeleteRecursively(const string& dirname, TF_Status* status) {
|
||||
int64 undeleted_files, undeleted_dirs;
|
||||
tensorflow::Status status = tensorflow::Env::Default()->DeleteRecursively(
|
||||
tensorflow::Status s = tensorflow::Env::Default()->DeleteRecursively(
|
||||
dirname, &undeleted_files, &undeleted_dirs);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return;
|
||||
}
|
||||
if (undeleted_files > 0 || undeleted_dirs > 0) {
|
||||
TF_SetStatus(out_status, TF_PERMISSION_DENIED,
|
||||
"could not fully delete dir");
|
||||
TF_SetStatus(status, TF_PERMISSION_DENIED, "could not fully delete dir");
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -169,22 +165,21 @@ bool IsDirectory(const string& dirname, TF_Status* out_status) {
|
||||
|
||||
using tensorflow::FileStatistics;
|
||||
|
||||
void Stat(const string& filename, FileStatistics* stats,
|
||||
TF_Status* out_status) {
|
||||
tensorflow::Status status = tensorflow::Env::Default()->Stat(filename,
|
||||
void Stat(const string& filename, FileStatistics* stats, TF_Status* status) {
|
||||
tensorflow::Status s = tensorflow::Env::Default()->Stat(filename,
|
||||
stats);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
|
||||
const string& filename, size_t buffer_size, TF_Status* out_status) {
|
||||
const string& filename, size_t buffer_size, TF_Status* status) {
|
||||
std::unique_ptr<tensorflow::RandomAccessFile> file;
|
||||
tensorflow::Status status =
|
||||
tensorflow::Status s =
|
||||
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return nullptr;
|
||||
}
|
||||
std::unique_ptr<tensorflow::io::RandomAccessInputStream> input_stream(
|
||||
@ -197,34 +192,34 @@ tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
|
||||
}
|
||||
|
||||
tensorflow::WritableFile* CreateWritableFile(
|
||||
const string& filename, const string& mode, TF_Status* out_status) {
|
||||
const string& filename, const string& mode, TF_Status* status) {
|
||||
std::unique_ptr<tensorflow::WritableFile> file;
|
||||
tensorflow::Status status;
|
||||
tensorflow::Status s;
|
||||
if (mode.find("a") != std::string::npos) {
|
||||
status = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
|
||||
s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
|
||||
} else {
|
||||
status = tensorflow::Env::Default()->NewWritableFile(filename, &file);
|
||||
s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
|
||||
}
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return nullptr;
|
||||
}
|
||||
return file.release();
|
||||
}
|
||||
|
||||
void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
|
||||
TF_Status* out_status) {
|
||||
tensorflow::Status status = file->Append(file_content);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
TF_Status* status) {
|
||||
tensorflow::Status s = file->Append(file_content);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
|
||||
int64 TellFile(tensorflow::WritableFile* file, TF_Status* out_status) {
|
||||
int64 TellFile(tensorflow::WritableFile* file, TF_Status* status) {
|
||||
int64 position = -1;
|
||||
tensorflow::Status status = file->Tell(&position);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Status s = file->Tell(&position);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
return position;
|
||||
}
|
||||
@ -232,11 +227,11 @@ int64 TellFile(tensorflow::WritableFile* file, TF_Status* out_status) {
|
||||
|
||||
string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
|
||||
size_t bytes,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
string result;
|
||||
tensorflow::Status status = stream->ReadNBytes(bytes, &result);
|
||||
if (!status.ok() && status.code() != tensorflow::error::OUT_OF_RANGE) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
tensorflow::Status s = stream->ReadNBytes(bytes, &result);
|
||||
if (!s.ok() && s.code() != tensorflow::error::OUT_OF_RANGE) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
result.clear();
|
||||
}
|
||||
return result;
|
||||
@ -250,35 +245,35 @@ string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
|
||||
%newobject CreateWritableFile;
|
||||
|
||||
// Wrap the above functions.
|
||||
inline void FileExists(const string& filename, TF_Status* out_status);
|
||||
inline void DeleteFile(const string& filename, TF_Status* out_status);
|
||||
string ReadFileToString(const string& filename, TF_Status* out_status);
|
||||
inline void FileExists(const string& filename, TF_Status* status);
|
||||
inline void DeleteFile(const string& filename, TF_Status* status);
|
||||
string ReadFileToString(const string& filename, TF_Status* status);
|
||||
void WriteStringToFile(const string& filename, const string& file_content,
|
||||
TF_Status* out_status);
|
||||
std::vector<string> GetChildren(const string& dir, TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
std::vector<string> GetChildren(const string& dir, TF_Status* status);
|
||||
std::vector<string> GetMatchingFiles(const string& filename,
|
||||
TF_Status* out_status);
|
||||
void CreateDir(const string& dirname, TF_Status* out_status);
|
||||
void RecursivelyCreateDir(const string& dirname, TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
void CreateDir(const string& dirname, TF_Status* status);
|
||||
void RecursivelyCreateDir(const string& dirname, TF_Status* status);
|
||||
void CopyFile(const string& oldpath, const string& newpath, bool overwrite,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
void RenameFile(const string& oldname, const string& newname, bool overwrite,
|
||||
TF_Status* out_status);
|
||||
void DeleteRecursively(const string& dirname, TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
void DeleteRecursively(const string& dirname, TF_Status* status);
|
||||
bool IsDirectory(const string& dirname, TF_Status* out_status);
|
||||
void Stat(const string& filename, tensorflow::FileStatistics* stats,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
|
||||
const string& filename, size_t buffer_size, TF_Status* out_status);
|
||||
const string& filename, size_t buffer_size, TF_Status* status);
|
||||
tensorflow::WritableFile* CreateWritableFile(const string& filename,
|
||||
const string& mode,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
|
||||
TF_Status* out_status);
|
||||
int64 TellFile(tensorflow::WritableFile* file, TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
int64 TellFile(tensorflow::WritableFile* file, TF_Status* status);
|
||||
string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
|
||||
size_t bytes,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
|
||||
%ignore tensorflow::Status::operator=;
|
||||
%include "tensorflow/core/lib/core/status.h"
|
||||
|
@ -80,18 +80,16 @@ class FileIO(object):
|
||||
if not self._read_check_passed:
|
||||
raise errors.PermissionDeniedError(None, None,
|
||||
"File isn't open for reading")
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._read_buf = pywrap_tensorflow.CreateBufferedInputStream(
|
||||
compat.as_bytes(self.__name), 1024 * 512, status)
|
||||
compat.as_bytes(self.__name), 1024 * 512)
|
||||
|
||||
def _prewrite_check(self):
|
||||
if not self._writable_file:
|
||||
if not self._write_check_passed:
|
||||
raise errors.PermissionDeniedError(None, None,
|
||||
"File isn't open for writing")
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._writable_file = pywrap_tensorflow.CreateWritableFile(
|
||||
compat.as_bytes(self.__name), compat.as_bytes(self.__mode), status)
|
||||
compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
|
||||
|
||||
def _prepare_value(self, val):
|
||||
if self._binary_mode:
|
||||
@ -106,9 +104,8 @@ class FileIO(object):
|
||||
def write(self, file_content):
|
||||
"""Writes file_content to the file. Appends to the end of the file."""
|
||||
self._prewrite_check()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.AppendToFile(
|
||||
compat.as_bytes(file_content), self._writable_file, status)
|
||||
compat.as_bytes(file_content), self._writable_file)
|
||||
|
||||
def read(self, n=-1):
|
||||
"""Returns the contents of a file as a string.
|
||||
@ -123,13 +120,12 @@ class FileIO(object):
|
||||
string if in string (regular) mode.
|
||||
"""
|
||||
self._preread_check()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if n == -1:
|
||||
length = self.size() - self.tell()
|
||||
else:
|
||||
length = n
|
||||
return self._prepare_value(
|
||||
pywrap_tensorflow.ReadFromStream(self._read_buf, length, status))
|
||||
pywrap_tensorflow.ReadFromStream(self._read_buf, length))
|
||||
|
||||
@deprecation.deprecated_args(
|
||||
None,
|
||||
@ -202,8 +198,7 @@ class FileIO(object):
|
||||
else:
|
||||
self._prewrite_check()
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
return pywrap_tensorflow.TellFile(self._writable_file, status)
|
||||
return pywrap_tensorflow.TellFile(self._writable_file)
|
||||
|
||||
def __enter__(self):
|
||||
"""Make usable with "with" statement."""
|
||||
@ -284,8 +279,7 @@ def file_exists_v2(path):
|
||||
errors.OpError: Propagates any errors reported by the FileSystem API.
|
||||
"""
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.FileExists(compat.as_bytes(path), status)
|
||||
pywrap_tensorflow.FileExists(compat.as_bytes(path))
|
||||
except errors.NotFoundError:
|
||||
return False
|
||||
return True
|
||||
@ -316,8 +310,7 @@ def delete_file_v2(path):
|
||||
errors.OpError: Propagates any errors reported by the FileSystem API. E.g.,
|
||||
NotFoundError if the path does not exist.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.DeleteFile(compat.as_bytes(path), status)
|
||||
pywrap_tensorflow.DeleteFile(compat.as_bytes(path))
|
||||
|
||||
|
||||
def read_file_to_string(filename, binary_mode=False):
|
||||
@ -385,21 +378,20 @@ def get_matching_files_v2(pattern):
|
||||
Raises:
|
||||
errors.OpError: If there are filesystem / directory listing errors.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if isinstance(pattern, six.string_types):
|
||||
return [
|
||||
# Convert the filenames to string from bytes.
|
||||
compat.as_str_any(matching_filename)
|
||||
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
|
||||
compat.as_bytes(pattern), status)
|
||||
compat.as_bytes(pattern))
|
||||
]
|
||||
else:
|
||||
return [
|
||||
# Convert the filenames to string from bytes.
|
||||
compat.as_str_any(matching_filename)
|
||||
compat.as_str_any(matching_filename) # pylint: disable=g-complex-comprehension
|
||||
for single_filename in pattern
|
||||
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
|
||||
compat.as_bytes(single_filename), status)
|
||||
compat.as_bytes(single_filename))
|
||||
]
|
||||
|
||||
|
||||
@ -434,8 +426,7 @@ def create_dir_v2(path):
|
||||
Raises:
|
||||
errors.OpError: If the operation fails.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.CreateDir(compat.as_bytes(path), status)
|
||||
pywrap_tensorflow.CreateDir(compat.as_bytes(path))
|
||||
|
||||
|
||||
@tf_export(v1=["gfile.MakeDirs"])
|
||||
@ -465,8 +456,7 @@ def recursive_create_dir_v2(path):
|
||||
Raises:
|
||||
errors.OpError: If the operation fails.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(path), status)
|
||||
pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(path))
|
||||
|
||||
|
||||
@tf_export(v1=["gfile.Copy"])
|
||||
@ -498,9 +488,8 @@ def copy_v2(src, dst, overwrite=False):
|
||||
Raises:
|
||||
errors.OpError: If the operation fails.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.CopyFile(
|
||||
compat.as_bytes(src), compat.as_bytes(dst), overwrite, status)
|
||||
compat.as_bytes(src), compat.as_bytes(dst), overwrite)
|
||||
|
||||
|
||||
@tf_export(v1=["gfile.Rename"])
|
||||
@ -532,9 +521,8 @@ def rename_v2(src, dst, overwrite=False):
|
||||
Raises:
|
||||
errors.OpError: If the operation fails.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.RenameFile(
|
||||
compat.as_bytes(src), compat.as_bytes(dst), overwrite, status)
|
||||
compat.as_bytes(src), compat.as_bytes(dst), overwrite)
|
||||
|
||||
|
||||
def atomic_write_string_to_file(filename, contents, overwrite=True):
|
||||
@ -584,8 +572,7 @@ def delete_recursively_v2(path):
|
||||
Raises:
|
||||
errors.OpError: If the operation fails.
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.DeleteRecursively(compat.as_bytes(path), status)
|
||||
pywrap_tensorflow.DeleteRecursively(compat.as_bytes(path))
|
||||
|
||||
|
||||
@tf_export(v1=["gfile.IsDirectory"])
|
||||
@ -655,13 +642,12 @@ def list_directory_v2(path):
|
||||
node_def=None,
|
||||
op=None,
|
||||
message="Could not find directory {}".format(path))
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
|
||||
# Convert each element to string, since the return values of the
|
||||
# vector of string should be interpreted as strings, not bytes.
|
||||
return [
|
||||
compat.as_str_any(filename)
|
||||
for filename in pywrap_tensorflow.GetChildren(
|
||||
compat.as_bytes(path), status)
|
||||
for filename in pywrap_tensorflow.GetChildren(compat.as_bytes(path))
|
||||
]
|
||||
|
||||
|
||||
@ -763,8 +749,7 @@ def stat_v2(path):
|
||||
errors.OpError: If the operation fails.
|
||||
"""
|
||||
file_statistics = pywrap_tensorflow.FileStatistics()
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
pywrap_tensorflow.Stat(compat.as_bytes(path), file_statistics, status)
|
||||
pywrap_tensorflow.Stat(compat.as_bytes(path), file_statistics)
|
||||
return file_statistics
|
||||
|
||||
|
||||
|
@ -22,19 +22,19 @@ limitations under the License.
|
||||
static PyObject* DoQuantizeTrainingOnGraphDefHelper(
|
||||
const string& input_graph,
|
||||
int num_bits,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
string result;
|
||||
// TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
|
||||
tensorflow::Status status =
|
||||
tensorflow::Status s =
|
||||
tensorflow::DoQuantizeTrainingOnSerializedGraphDef(input_graph, num_bits,
|
||||
"QuantizeAndDequantizeV2", &result);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
|
||||
if (!py_str) {
|
||||
Set_TF_Status_from_Status(out_status,
|
||||
Set_TF_Status_from_Status(status,
|
||||
tensorflow::Status(tensorflow::error::INTERNAL,
|
||||
"Failed to generate serialized string of the rewritten graph."));
|
||||
Py_RETURN_NONE;
|
||||
@ -51,7 +51,7 @@ static PyObject* DoQuantizeTrainingOnGraphDefHelper(
|
||||
PyObject* DoQuantizeTrainingOnGraphDefHelper(
|
||||
const string& input_graph,
|
||||
int num_bits,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
|
||||
|
||||
%insert("python") %{
|
||||
@ -70,10 +70,10 @@ def do_quantize_training_on_graphdef(input_graph, num_bits):
|
||||
"""
|
||||
from tensorflow.core.framework.graph_pb2 import GraphDef
|
||||
from tensorflow.python.framework import errors
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
|
||||
graph = GraphDef()
|
||||
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
|
||||
input_graph.SerializeToString(), num_bits, status)
|
||||
input_graph.SerializeToString(), num_bits)
|
||||
|
||||
graph.ParseFromString(result_graph_string)
|
||||
return graph
|
||||
|
@ -104,15 +104,15 @@ limitations under the License.
|
||||
static PyObject* CheckpointReader_GetTensor(
|
||||
tensorflow::checkpoint::CheckpointReader* reader,
|
||||
const string& name,
|
||||
TF_Status* out_status) {
|
||||
TF_Status* status) {
|
||||
PyObject* py_obj = Py_None;
|
||||
std::unique_ptr<tensorflow::Tensor> tensor;
|
||||
reader->GetTensor(name, &tensor, out_status);
|
||||
if (TF_GetCode(out_status) == TF_OK) {
|
||||
tensorflow::Status status =
|
||||
reader->GetTensor(name, &tensor, status);
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
tensorflow::Status s =
|
||||
tensorflow::ConvertTensorToNdarray(*tensor.get(), &py_obj);
|
||||
if (!status.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, status);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
}
|
||||
return py_obj;
|
||||
@ -123,7 +123,7 @@ static PyObject* CheckpointReader_GetTensor(
|
||||
PyObject* CheckpointReader_GetTensor(
|
||||
tensorflow::checkpoint::CheckpointReader* reader,
|
||||
const string& name,
|
||||
TF_Status* out_status);
|
||||
TF_Status* status);
|
||||
|
||||
%ignoreall
|
||||
|
||||
@ -150,20 +150,16 @@ PyObject* CheckpointReader_GetTensor(
|
||||
return self._HasTensor(compat.as_bytes(tensor_str))
|
||||
|
||||
def get_tensor(self, tensor_str):
|
||||
from tensorflow.python.framework import errors
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
from tensorflow.python.util import compat
|
||||
return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str),
|
||||
status)
|
||||
|
||||
return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
|
||||
%}
|
||||
}
|
||||
|
||||
%insert("python") %{
|
||||
def NewCheckpointReader(filepattern):
|
||||
from tensorflow.python.framework import errors
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
from tensorflow.python.util import compat
|
||||
return CheckpointReader(compat.as_bytes(filepattern), status)
|
||||
return CheckpointReader(compat.as_bytes(filepattern))
|
||||
|
||||
NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader']
|
||||
%}
|
||||
|
Loading…
Reference in New Issue
Block a user