Use TF_Status* SWIG typemap

This avoids the need to use raise_exception_on_not_ok_status
which has been removed from the v2 API.

PiperOrigin-RevId: 244049607
This commit is contained in:
Gaurav Jain 2019-04-17 12:51:53 -07:00 committed by TensorFlower Gardener
parent 4509bdfd5c
commit 3afb32ce1f
24 changed files with 289 additions and 347 deletions

View File

@ -29,8 +29,7 @@ namespace checkpoint {
class TensorSliceReader;
CheckpointReader::CheckpointReader(const string& filename,
TF_Status* out_status)
CheckpointReader::CheckpointReader(const string& filename, TF_Status* status)
: reader_(nullptr),
v2_reader_(nullptr),
var_to_shape_map_(nullptr),
@ -43,7 +42,7 @@ CheckpointReader::CheckpointReader(const string& filename,
v2_reader_.reset(
new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */));
if (!v2_reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, v2_reader_->status());
Set_TF_Status_from_Status(status, v2_reader_->status());
return;
}
auto result = BuildV2VarMaps();
@ -52,7 +51,7 @@ CheckpointReader::CheckpointReader(const string& filename,
} else {
reader_.reset(new TensorSliceReader(filename));
if (!reader_->status().ok()) {
Set_TF_Status_from_Status(out_status, reader_->status());
Set_TF_Status_from_Status(status, reader_->status());
return;
}
var_to_shape_map_.reset(

View File

@ -39,7 +39,7 @@ class TensorSliceReader;
// variables.
class CheckpointReader {
public:
CheckpointReader(const string& filepattern, TF_Status* out_status);
CheckpointReader(const string& filename, TF_Status* status);
bool HasTensor(const string& name) const;
const string DebugString() const;

View File

@ -44,15 +44,14 @@ namespace tensorflow {
namespace swig {
static std::vector<string> ListDevicesWithSessionConfig(
const tensorflow::ConfigProto& config, TF_Status* out_status) {
const tensorflow::ConfigProto& config, TF_Status* status) {
std::vector<string> output;
SessionOptions options;
options.config = config;
std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::AddDevices(
options, "" /* name_prefix */, &devices);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
Status s = DeviceFactory::AddDevices(options, "" /* name_prefix */, &devices);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
for (const std::unique_ptr<Device>& device : devices) {
@ -60,7 +59,7 @@ static std::vector<string> ListDevicesWithSessionConfig(
string attr_serialized;
if (!attr.SerializeToString(&attr_serialized)) {
Set_TF_Status_from_Status(
out_status,
status,
errors::Internal("Could not serialize device string"));
output.clear();
return output;
@ -71,9 +70,9 @@ static std::vector<string> ListDevicesWithSessionConfig(
return output;
}
std::vector<string> ListDevices(TF_Status* out_status) {
std::vector<string> ListDevices(TF_Status* status) {
tensorflow::ConfigProto session_config;
return ListDevicesWithSessionConfig(session_config, out_status);
return ListDevicesWithSessionConfig(session_config, status);
}
} // namespace swig
@ -91,9 +90,9 @@ std::vector<string> ListDevices(TF_Status* out_status) {
// Wrap this function
namespace tensorflow {
namespace swig {
std::vector<string> ListDevices(TF_Status* out_status);
std::vector<string> ListDevices(TF_Status* status);
static std::vector<string> ListDevicesWithSessionConfig(
const tensorflow::ConfigProto& config, TF_Status* out_status);
const tensorflow::ConfigProto& config, TF_Status* status);
} // namespace swig
} // namespace tensorflow
@ -101,12 +100,10 @@ static std::vector<string> ListDevicesWithSessionConfig(
def list_devices(session_config=None):
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
if session_config:
return ListDevicesWithSessionConfig(session_config.SerializeToString(),
status)
else:
return ListDevices(status)
if session_config:
return ListDevicesWithSessionConfig(session_config.SerializeToString())
else:
return ListDevices()
%}
%unignoreall

View File

@ -1425,9 +1425,8 @@ class BaseSession(SessionInterface):
options_ptr = tf_session.TF_NewBufferFromString(
compat.as_bytes(callable_options.SerializeToString()))
try:
with errors.raise_exception_on_not_ok_status() as status:
self._handle = tf_session.TF_SessionMakeCallable(
session._session, options_ptr, status)
self._handle = tf_session.TF_SessionMakeCallable(
session._session, options_ptr)
finally:
tf_session.TF_DeleteBuffer(options_ptr)
@ -1437,11 +1436,9 @@ class BaseSession(SessionInterface):
run_metadata = kwargs.get('run_metadata', None)
try:
run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
# TODO(mrry): Switch to raising an exception from the SWIG wrapper.
with errors.raise_exception_on_not_ok_status() as status:
ret = tf_session.TF_SessionRunCallable(
self._session._session, self._handle, args, status,
run_metadata_ptr)
ret = tf_session.TF_SessionRunCallable(self._session._session,
self._handle, args,
run_metadata_ptr)
if run_metadata:
proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
run_metadata.ParseFromString(compat.as_bytes(proto_data))
@ -1455,9 +1452,8 @@ class BaseSession(SessionInterface):
# called before this destructor, in which case `self._session._session`
# will be `None`.
if self._handle is not None and self._session._session is not None:
with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_SessionReleaseCallable(
self._session._session, self._handle, status)
tf_session.TF_SessionReleaseCallable(
self._session._session, self._handle)
# pylint: enable=protected-access
# TODO(b/74355905): Reimplement `Session.make_callable()` using this method

View File

@ -571,8 +571,7 @@ def TF_Reset(target, containers=None, config=None):
from tensorflow.python.framework import errors
opts = TF_NewSessionOptions(target=target, config=config)
try:
with errors.raise_exception_on_not_ok_status() as status:
TF_Reset_wrapper(opts, containers, status)
TF_Reset_wrapper(opts, containers)
finally:
TF_DeleteSessionOptions(opts)
%}

View File

@ -194,16 +194,13 @@ void MakeCallableHelper(tensorflow::Session* session,
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
const TF_Buffer* callable_options,
int64_t* out_handle,
TF_Status* out_status) {
MakeCallableHelper(session->session, callable_options, out_handle,
out_status);
int64_t* out_handle, TF_Status* status) {
MakeCallableHelper(session->session, callable_options, out_handle, status);
}
void TF_SessionMakeCallable(TF_Session* session,
const TF_Buffer* callable_options,
int64_t* out_handle, TF_Status* out_status) {
MakeCallableHelper(session->session, callable_options, out_handle,
out_status);
int64_t* out_handle, TF_Status* status) {
MakeCallableHelper(session->session, callable_options, out_handle, status);
}
namespace {
@ -291,32 +288,28 @@ void RunCallableHelper(tensorflow::Session* session, int64_t handle,
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
int64_t handle, PyObject* feed_values,
TF_Status* out_status,
PyObjectVector* out_values,
TF_Buffer* run_metadata) {
RunCallableHelper(session->session, handle, feed_values, out_status,
out_values, run_metadata);
TF_Buffer* run_metadata,
TF_Status* status) {
RunCallableHelper(session->session, handle, feed_values, status, out_values,
run_metadata);
ClearDecrefCache();
}
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
PyObject* feed_values, TF_Status* out_status,
PyObjectVector* out_values,
TF_Buffer* run_metadata) {
RunCallableHelper(session->session, handle, feed_values, out_status,
out_values, run_metadata);
PyObject* feed_values, PyObjectVector* out_values,
TF_Buffer* run_metadata, TF_Status* status) {
RunCallableHelper(session->session, handle, feed_values, status, out_values,
run_metadata);
ClearDecrefCache();
}
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
int64_t handle,
TF_Status* out_status) {
Set_TF_Status_from_Status(out_status,
session->session->ReleaseCallable(handle));
int64_t handle, TF_Status* status) {
Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
}
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
TF_Status* out_status) {
Set_TF_Status_from_Status(out_status,
session->session->ReleaseCallable(handle));
TF_Status* status) {
Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
}
// Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
@ -348,9 +341,9 @@ void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
// Wrapper for TF_Reset that converts the string vectors to character arrays.
void TF_Reset_wrapper(const TF_SessionOptions* opt,
const NameVector& containers, TF_Status* out_status) {
const NameVector& containers, TF_Status* status) {
TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
out_status);
status);
}
void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle,

View File

@ -65,27 +65,26 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
// Python wrappers for the `Session::MakeCallable()` API.
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
const TF_Buffer* callable_options,
int64_t* out_handle,
TF_Status* out_status);
int64_t* out_handle, TF_Status* status);
void TF_SessionMakeCallable(TF_Session* session,
const TF_Buffer* callable_options,
int64_t* out_handle, TF_Status* out_status);
int64_t* out_handle, TF_Status* status);
// Python wrappers for the `Session::RunCallable()` API.
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
int64_t handle, PyObject* feed_values,
TF_Status* out_status,
PyObjectVector* out_values,
TF_Buffer* run_metadata);
TF_Buffer* run_metadata,
TF_Status* status);
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
PyObject* feed_values, TF_Status* out_status,
PyObjectVector* out_values, TF_Buffer* run_metadata);
PyObject* feed_values, PyObjectVector* out_values,
TF_Buffer* run_metadata, TF_Status* status);
// Python wrappers for the `Session::ReleaseCallable()` API.
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
int64_t handle, TF_Status* out_status);
int64_t handle, TF_Status* status);
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
TF_Status* out_status);
TF_Status* status);
// Set up the graph with the intended feeds and fetches for partial run.
// *out_handle is owned by the caller.
@ -118,7 +117,7 @@ void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
// Wrapper for TF_Reset that converts the string vectors to character arrays.
void TF_Reset_wrapper(const TF_SessionOptions* opt,
const NameVector& containers, TF_Status* out_status);
const NameVector& containers, TF_Status* status);
// Convenience wrapper around EqualGraphDef to make it easier to wrap.
// Returns an explanation if a difference is found, or the empty string

View File

@ -45,7 +45,6 @@ def start_tracing(service_addr,
Raises:
UnavailableError: If no trace event is collected.
"""
# TODO(fishx): Uses errors.raise_exception_on_not_ok_status instead.
if not pywrap_tensorflow.TFE_ProfilerClientStartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts):

View File

@ -692,10 +692,9 @@ def _call_cpp_shape_fn_impl(
missing_shape_fn = False
try:
with errors.raise_exception_on_not_ok_status() as status:
output = pywrap_tensorflow.RunCppShapeInference(
graph_def_version, node_def_str, input_shapes, input_tensors,
input_tensors_as_shapes, status)
output = pywrap_tensorflow.RunCppShapeInference(
graph_def_version, node_def_str, input_shapes, input_tensors,
input_tensors_as_shapes)
except errors.InvalidArgumentError as err:
if err.message.startswith("No shape inference function exists for op"):
missing_shape_fn = True

View File

@ -175,9 +175,9 @@ std::vector<string> RunCppShapeInference(
const std::vector<string>& input_serialized_shapes,
PyObject* input_constant_tensor_values,
const std::vector<string>& input_constant_tensor_as_shape_values,
TF_Status* out_status) {
TF_Status* status) {
if (!PyList_Check(input_constant_tensor_values)) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Invalid python value");
TF_SetStatus(status, TF_INVALID_ARGUMENT, "Invalid python value");
return std::vector<string>();
}
@ -191,13 +191,13 @@ std::vector<string> RunCppShapeInference(
std::vector<string> output;
string input_tensors_needed_out;
tensorflow::Status status = RunCppShapeInferenceImpl(
tensorflow::Status s = RunCppShapeInferenceImpl(
graph_def_version, serialized_node_def, input_serialized_shapes,
input_constant_tensor_values_v, input_constant_tensor_as_shape_values,
&output, &input_tensors_needed_out);
Set_TF_Status_from_Status(out_status, status);
if (!status.ok()) {
Set_TF_Status_from_Status(status, s);
if (!s.ok()) {
return std::vector<string>();
}
output.push_back(input_tensors_needed_out);

View File

@ -46,7 +46,7 @@ std::vector<string> RunCppShapeInference(
const std::vector<string>& input_serialized_shapes,
PyObject* input_constant_tensor_values,
const std::vector<string>& input_constant_tensor_as_shape_values,
TF_Status* out_status);
TF_Status* status);
} // namespace swig
} // namespace tensorflow

View File

@ -112,9 +112,7 @@ class ErrorsTest(test.TestCase):
def testStatusDoesNotLeak(self):
try:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.DeleteFile(
compat.as_bytes("/DOES_NOT_EXIST/"), status)
pywrap_tensorflow.DeleteFile(compat.as_bytes("/DOES_NOT_EXIST/"))
except:
pass
gc.collect()

View File

@ -131,7 +131,7 @@ struct GCluster {
static GCluster TF_NewCluster(bool allow_soft_placement,
bool disable_detailed_stats, TF_Status* out_status) {
bool disable_detailed_stats, TF_Status* status) {
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
int num_gpus = tensorflow::grappler::GetNumAvailableGPUs();
int timeout_s = 60 * 10;
@ -141,14 +141,13 @@ static GCluster TF_NewCluster(bool allow_soft_placement,
cluster_->DisableDetailedStats(disable_detailed_stats);
cluster_->AllowSoftPlacement(allow_soft_placement);
cluster_->SetNumWarmupSteps(10);
tensorflow::Status status = cluster_->Provision();
tensorflow::Set_TF_Status_from_Status(out_status, status);
tensorflow::Status s = cluster_->Provision();
tensorflow::Set_TF_Status_from_Status(status, s);
return GCluster(cluster_);
}
static GCluster TF_NewVirtualCluster(
const std::vector<tensorflow::NamedDevice>& named_devices,
TF_Status* out_status) {
const std::vector<tensorflow::NamedDevice>& named_devices, TF_Status* status) {
std::unordered_map<string, tensorflow::DeviceProperties> devices;
for (const auto& named_device : named_devices) {
devices[named_device.name()]= named_device.properties();
@ -156,9 +155,9 @@ static GCluster TF_NewVirtualCluster(
tensorflow::grappler::Cluster* cluster_ =
new tensorflow::grappler::VirtualCluster(devices);
PyGILState_STATE gstate = PyGILState_Ensure();
tensorflow::Status status = cluster_->Provision();
tensorflow::Status s = cluster_->Provision();
PyGILState_Release(gstate);
tensorflow::Set_TF_Status_from_Status(out_status, status);
tensorflow::Set_TF_Status_from_Status(status, s);
return GCluster(cluster_);
}
@ -316,7 +315,7 @@ static double TF_EstimatePerformance(const tensorflow::NamedDevice& device) {
static PyObject* TF_MeasureCosts(
GItem item,
GCluster cluster,
bool generate_timeline, TF_Status* out_status) {
bool generate_timeline, TF_Status* status) {
tensorflow::OpPerformanceList op_performance_data;
tensorflow::StepStats step_stats;
@ -324,25 +323,25 @@ static PyObject* TF_MeasureCosts(
tensorflow::grappler::MeasuringCostEstimator cost_measure(cluster.get(), num_measurements, 0);
tensorflow::grappler::Costs costs;
tensorflow::Status status = _GetOpPerformanceDataAndRunTime(
tensorflow::Status s = _GetOpPerformanceDataAndRunTime(
*item, &cost_measure, &op_performance_data, &costs);
double run_time = FLT_MAX;
if (status.ok()) {
if (s.ok()) {
run_time = static_cast<double>(costs.execution_time.count()) / 1e9;
}
if (generate_timeline) {
tensorflow::RunMetadata metadata;
tensorflow::Status s = cluster->Run(
tensorflow::Status run_status = cluster->Run(
item->graph, item->feed, item->fetch, &metadata);
if (s.ok()) {
if (run_status.ok()) {
step_stats = metadata.step_stats();
} else {
status = s;
s = run_status;
}
}
tensorflow::Set_TF_Status_from_Status(out_status, status);
if (!status.ok()) {
tensorflow::Set_TF_Status_from_Status(status, s);
if (!s.ok()) {
Py_RETURN_NONE;
}
PyGILState_STATE gstate = PyGILState_Ensure();
@ -370,9 +369,9 @@ static PyObject* TF_MeasureCosts(
Py_XDECREF(op_perf_objs);
Py_XDECREF(run_time_obj);
Py_XDECREF(metadata_obj);
status = tensorflow::Status(tensorflow::error::Code::INTERNAL,
"Error setting return tuples.");
tensorflow::Set_TF_Status_from_Status(out_status, status);
s = tensorflow::Status(tensorflow::error::Code::INTERNAL,
"Error setting return tuples.");
tensorflow::Set_TF_Status_from_Status(status, s);
Py_INCREF(Py_None);
ret = Py_None;
}
@ -384,23 +383,23 @@ static PyObject* TF_MeasureCosts(
static PyObject* TF_DeterminePeakMemoryUsage(
GItem item,
GCluster cluster,
TF_Status* out_status) {
TF_Status* status) {
if (item.is_none() || cluster.is_none()) {
tensorflow::Status status(tensorflow::error::Code::INTERNAL,
"You need both a cluster and an item to determine peak memory usage");
tensorflow::Set_TF_Status_from_Status(out_status, status);
tensorflow::Status s(tensorflow::error::Code::INTERNAL,
"You need both a cluster and an item to determine peak memory usage");
tensorflow::Set_TF_Status_from_Status(status, s);
Py_RETURN_NONE;
}
tensorflow::grappler::GraphMemory memory(*item);
tensorflow::Status status;
tensorflow::Status s;
if (cluster->DetailedStatsEnabled()) {
status = memory.InferDynamically(cluster.get());
s = memory.InferDynamically(cluster.get());
} else {
status = memory.InferStatically(cluster->GetDevices());
s = memory.InferStatically(cluster->GetDevices());
}
if (!status.ok()) {
tensorflow::Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
tensorflow::Set_TF_Status_from_Status(status, s);
Py_RETURN_NONE;
}
@ -434,10 +433,10 @@ static PyObject* TF_DeterminePeakMemoryUsage(
// Wrap these functions.
static GCluster TF_NewCluster(
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* out_status);
bool allow_soft_placement, bool disable_detailed_stats, TF_Status* status);
static GCluster TF_NewVirtualCluster(
const std::vector<tensorflow::NamedDevice>& named_devices,
TF_Status* out_status);
TF_Status* status);
static void TF_ShutdownCluster(GCluster cluster);
static PyObject* TF_ListDevices(GCluster cluster);
static PyObject* TF_ListAvailableOps();
@ -445,7 +444,7 @@ static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
static PyObject* TF_MeasureCosts(
GItem item, GCluster cluster,
bool generate_timeline, TF_Status* out_status);
bool generate_timeline, TF_Status* status);
static PyObject* TF_DeterminePeakMemoryUsage(
GItem item, GCluster cluster,
TF_Status* out_status);
TF_Status* status);

View File

@ -24,7 +24,6 @@ from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.grappler.costs import op_performance_data_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.python import pywrap_tensorflow as tf_cluster
from tensorflow.python.framework import errors
class Cluster(object):
@ -49,14 +48,13 @@ class Cluster(object):
"""
self._tf_cluster = None
self._generate_timeline = not disable_timeline
with errors.raise_exception_on_not_ok_status() as status:
if devices is None:
self._tf_cluster = tf_cluster.TF_NewCluster(
allow_soft_placement, disable_detailed_stats, status)
else:
devices_serialized = [device.SerializeToString() for device in devices]
self._tf_cluster = tf_cluster.TF_NewVirtualCluster(
devices_serialized, status)
if devices is None:
self._tf_cluster = tf_cluster.TF_NewCluster(allow_soft_placement,
disable_detailed_stats)
else:
devices_serialized = [device.SerializeToString() for device in devices]
self._tf_cluster = tf_cluster.TF_NewVirtualCluster(devices_serialized)
def Shutdown(self):
if self._tf_cluster is not None:
@ -94,9 +92,8 @@ class Cluster(object):
item: The item for which to measure the costs.
Returns: The triplet op_perfs, runtime, step_stats.
"""
with errors.raise_exception_on_not_ok_status() as status:
ret_from_swig = tf_cluster.TF_MeasureCosts(
item.tf_item, self._tf_cluster, self._generate_timeline, status)
ret_from_swig = tf_cluster.TF_MeasureCosts(item.tf_item, self._tf_cluster,
self._generate_timeline)
if ret_from_swig is None:
return None
@ -114,9 +111,8 @@ class Cluster(object):
item: The item for which to measure the costs.
Returns: A hashtable indexed by device name.
"""
with errors.raise_exception_on_not_ok_status() as status:
return tf_cluster.TF_DeterminePeakMemoryUsage(
item.tf_item, self._tf_cluster, status)
return tf_cluster.TF_DeterminePeakMemoryUsage(item.tf_item,
self._tf_cluster)
@contextlib.contextmanager

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python.framework import errors
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import item as gitem
@ -45,10 +44,9 @@ def GenerateCostReport(metagraph,
if cluster is None:
cluster = gcluster.Cluster(disable_detailed_stats=False)
with errors.raise_exception_on_not_ok_status():
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
per_node_report, verbose,
cluster.tf_cluster)
ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(),
per_node_report, verbose,
cluster.tf_cluster)
return ret_from_swig

View File

@ -72,10 +72,10 @@ struct GItem {
static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* out_status) {
bool ignore_user_placement, TF_Status* status) {
if (meta_graph.collection_def().count("train_op") == 0) {
tensorflow::Set_TF_Status_from_Status(
out_status,
status,
tensorflow::errors::InvalidArgument("train_op not specified in the metagraph"));
return nullptr;
}
@ -87,11 +87,11 @@ static GItem TF_NewItem(
tensorflow::grappler::GrapplerItemFromMetaGraphDef("item", meta_graph, cfg);
if (!item) {
tensorflow::Set_TF_Status_from_Status(
out_status,
status,
tensorflow::errors::InvalidArgument("Invalid metagraph"));
return nullptr;
}
tensorflow::Set_TF_Status_from_Status(out_status, tensorflow::Status::OK());
tensorflow::Set_TF_Status_from_Status(status, tensorflow::Status::OK());
return GItem(item.release());
}
@ -308,7 +308,7 @@ static PyObject* TF_GetColocationGroups(GItem item) {
// Wrap these functions.
static GItem TF_NewItem(
const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
bool ignore_user_placement, TF_Status* out_status);
bool ignore_user_placement, TF_Status* status);
static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
TF_Status* status);
static PyObject* TF_GetOpProperties(GItem item);

View File

@ -21,7 +21,6 @@ from __future__ import print_function
from tensorflow.core.grappler.costs import op_performance_data_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python import pywrap_tensorflow as tf_item
from tensorflow.python.framework import errors
class Item(object):
@ -87,7 +86,6 @@ class Item(object):
return self._tf_item
def _BuildTFItem(self):
with errors.raise_exception_on_not_ok_status() as status:
self._tf_item = tf_item.TF_NewItem(self._metagraph.SerializeToString(),
self._ignore_colocation,
self._ignore_user_placement, status)
self._tf_item = tf_item.TF_NewItem(self._metagraph.SerializeToString(),
self._ignore_colocation,
self._ignore_user_placement)

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python.framework import errors
def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
@ -33,8 +32,7 @@ def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
Returns:
A string containing the report.
"""
with errors.raise_exception_on_not_ok_status():
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
assume_valid_feeds, debug)
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
assume_valid_feeds, debug)
return ret_from_swig

View File

@ -95,7 +95,7 @@ PyObject* TF_OptimizeGraph(
GCluster cluster,
const tensorflow::ConfigProto& config_proto,
const tensorflow::MetaGraphDef& metagraph,
bool verbose, const string& graph_id, TF_Status* out_status) {
bool verbose, const string& graph_id, TF_Status* status) {
tensorflow::grappler::ItemConfig item_config;
item_config.apply_optimizations = false;
item_config.ignore_user_placement = false;
@ -103,18 +103,18 @@ PyObject* TF_OptimizeGraph(
tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config);
if (!grappler_item) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
TF_SetStatus(status, TF_INVALID_ARGUMENT, "Failed to import metagraph, check error log for more info.");
return nullptr;
}
tensorflow::DeviceBase* cpu_device = nullptr;
tensorflow::GraphDef out_graph;
tensorflow::grappler::MetaOptimizer optimizer(cpu_device, config_proto);
tensorflow::Status status = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
tensorflow::Status s = optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
if (verbose) {
optimizer.PrintResult();
}
tensorflow::Set_TF_Status_from_Status(out_status, status);
tensorflow::Set_TF_Status_from_Status(status, s);
string out_graph_str = out_graph.SerializeAsString();
PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
out_graph_str.size());
@ -128,7 +128,7 @@ PyObject* TF_OptimizeGraph(
GCluster cluster,
const tensorflow::ConfigProto& config_proto,
const tensorflow::MetaGraphDef& metagraph, bool verbose,
const string& graph_id, TF_Status* out_status);
const string& graph_id, TF_Status* status);

View File

@ -21,7 +21,6 @@ from __future__ import print_function
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_opt
from tensorflow.python.framework import errors
from tensorflow.python.grappler import cluster as gcluster
@ -34,13 +33,12 @@ def OptimizeGraph(config_proto,
if not isinstance(config_proto, config_pb2.ConfigProto):
raise TypeError('Expected config_proto to be a ConfigProto, saw type %s' %
type(config_proto))
with errors.raise_exception_on_not_ok_status() as status:
if cluster is None:
cluster = gcluster.Cluster()
ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
config_proto.SerializeToString(),
metagraph.SerializeToString(),
verbose, graph_id, status)
if cluster is None:
cluster = gcluster.Cluster()
ret_from_swig = tf_opt.TF_OptimizeGraph(cluster.tf_cluster,
config_proto.SerializeToString(),
metagraph.SerializeToString(),
verbose, graph_id)
if ret_from_swig is None:
return None
out_graph = graph_pb2.GraphDef().FromString(ret_from_swig)

View File

@ -32,125 +32,121 @@ limitations under the License.
%}
%{
inline void FileExists(const string& filename, TF_Status* out_status) {
tensorflow::Status status = tensorflow::Env::Default()->FileExists(filename);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
inline void FileExists(const string& filename, TF_Status* status) {
tensorflow::Status s = tensorflow::Env::Default()->FileExists(filename);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
inline void FileExists(const tensorflow::StringPiece& filename,
TF_Status* out_status) {
tensorflow::Status status =
TF_Status* status) {
tensorflow::Status s =
tensorflow::Env::Default()->FileExists(string(filename));
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
inline void DeleteFile(const string& filename, TF_Status* out_status) {
tensorflow::Status status = tensorflow::Env::Default()->DeleteFile(filename);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
inline void DeleteFile(const string& filename, TF_Status* status) {
tensorflow::Status s = tensorflow::Env::Default()->DeleteFile(filename);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
string ReadFileToString(const string& filename, TF_Status* out_status) {
string ReadFileToString(const string& filename, TF_Status* status) {
string file_content;
tensorflow::Status status = ReadFileToString(tensorflow::Env::Default(),
tensorflow::Status s = ReadFileToString(tensorflow::Env::Default(),
filename, &file_content);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
return file_content;
}
void WriteStringToFile(const string& filename, const string& file_content,
TF_Status* out_status) {
tensorflow::Status status = WriteStringToFile(tensorflow::Env::Default(),
TF_Status* status) {
tensorflow::Status s = WriteStringToFile(tensorflow::Env::Default(),
filename, file_content);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
std::vector<string> GetChildren(const string& dir, TF_Status* out_status) {
std::vector<string> GetChildren(const string& dir, TF_Status* status) {
std::vector<string> results;
tensorflow::Status status = tensorflow::Env::Default()->GetChildren(
tensorflow::Status s = tensorflow::Env::Default()->GetChildren(
dir, &results);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
return results;
}
std::vector<string> GetMatchingFiles(const string& filename,
TF_Status* out_status) {
std::vector<string> GetMatchingFiles(const string& filename, TF_Status* status) {
std::vector<string> results;
tensorflow::Status status = tensorflow::Env::Default()->GetMatchingPaths(
tensorflow::Status s = tensorflow::Env::Default()->GetMatchingPaths(
filename, &results);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
return results;
}
void CreateDir(const string& dirname, TF_Status* out_status) {
tensorflow::Status status = tensorflow::Env::Default()->CreateDir(dirname);
if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
Set_TF_Status_from_Status(out_status, status);
void CreateDir(const string& dirname, TF_Status* status) {
tensorflow::Status s = tensorflow::Env::Default()->CreateDir(dirname);
if (!s.ok() && s.code() != tensorflow::error::ALREADY_EXISTS) {
Set_TF_Status_from_Status(status, s);
}
}
void RecursivelyCreateDir(const string& dirname, TF_Status* out_status) {
tensorflow::Status status = tensorflow::Env::Default()->RecursivelyCreateDir(
void RecursivelyCreateDir(const string& dirname, TF_Status* status) {
tensorflow::Status s = tensorflow::Env::Default()->RecursivelyCreateDir(
dirname);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
void CopyFile(const string& src, const string& target, bool overwrite,
TF_Status* out_status) {
TF_Status* status) {
// If overwrite is false and the target file exists then its an error.
if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) {
TF_SetStatus(out_status, TF_ALREADY_EXISTS, "file already exists");
TF_SetStatus(status, TF_ALREADY_EXISTS, "file already exists");
return;
}
tensorflow::Status status =
tensorflow::Env::Default()->CopyFile(src, target);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
tensorflow::Status s = tensorflow::Env::Default()->CopyFile(src, target);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
void RenameFile(const string& src, const string& target, bool overwrite,
TF_Status* out_status) {
TF_Status* status) {
// If overwrite is false and the target file exists then its an error.
if (!overwrite && tensorflow::Env::Default()->FileExists(target).ok()) {
TF_SetStatus(out_status, TF_ALREADY_EXISTS, "file already exists");
TF_SetStatus(status, TF_ALREADY_EXISTS, "file already exists");
return;
}
tensorflow::Status status = tensorflow::Env::Default()->RenameFile(src,
target);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
tensorflow::Status s = tensorflow::Env::Default()->RenameFile(src, target);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
using tensorflow::int64;
void DeleteRecursively(const string& dirname, TF_Status* out_status) {
void DeleteRecursively(const string& dirname, TF_Status* status) {
int64 undeleted_files, undeleted_dirs;
tensorflow::Status status = tensorflow::Env::Default()->DeleteRecursively(
tensorflow::Status s = tensorflow::Env::Default()->DeleteRecursively(
dirname, &undeleted_files, &undeleted_dirs);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
if (undeleted_files > 0 || undeleted_dirs > 0) {
TF_SetStatus(out_status, TF_PERMISSION_DENIED,
"could not fully delete dir");
TF_SetStatus(status, TF_PERMISSION_DENIED, "could not fully delete dir");
return;
}
}
@ -169,22 +165,21 @@ bool IsDirectory(const string& dirname, TF_Status* out_status) {
using tensorflow::FileStatistics;
void Stat(const string& filename, FileStatistics* stats,
TF_Status* out_status) {
tensorflow::Status status = tensorflow::Env::Default()->Stat(filename,
void Stat(const string& filename, FileStatistics* stats, TF_Status* status) {
tensorflow::Status s = tensorflow::Env::Default()->Stat(filename,
stats);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
const string& filename, size_t buffer_size, TF_Status* out_status) {
const string& filename, size_t buffer_size, TF_Status* status) {
std::unique_ptr<tensorflow::RandomAccessFile> file;
tensorflow::Status status =
tensorflow::Status s =
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return nullptr;
}
std::unique_ptr<tensorflow::io::RandomAccessInputStream> input_stream(
@ -197,34 +192,34 @@ tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
}
tensorflow::WritableFile* CreateWritableFile(
const string& filename, const string& mode, TF_Status* out_status) {
const string& filename, const string& mode, TF_Status* status) {
std::unique_ptr<tensorflow::WritableFile> file;
tensorflow::Status status;
tensorflow::Status s;
if (mode.find("a") != std::string::npos) {
status = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
s = tensorflow::Env::Default()->NewAppendableFile(filename, &file);
} else {
status = tensorflow::Env::Default()->NewWritableFile(filename, &file);
s = tensorflow::Env::Default()->NewWritableFile(filename, &file);
}
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return nullptr;
}
return file.release();
}
void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
TF_Status* out_status) {
tensorflow::Status status = file->Append(file_content);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
TF_Status* status) {
tensorflow::Status s = file->Append(file_content);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
int64 TellFile(tensorflow::WritableFile* file, TF_Status* out_status) {
int64 TellFile(tensorflow::WritableFile* file, TF_Status* status) {
int64 position = -1;
tensorflow::Status status = file->Tell(&position);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
tensorflow::Status s = file->Tell(&position);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
return position;
}
@ -232,11 +227,11 @@ int64 TellFile(tensorflow::WritableFile* file, TF_Status* out_status) {
string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
size_t bytes,
TF_Status* out_status) {
TF_Status* status) {
string result;
tensorflow::Status status = stream->ReadNBytes(bytes, &result);
if (!status.ok() && status.code() != tensorflow::error::OUT_OF_RANGE) {
Set_TF_Status_from_Status(out_status, status);
tensorflow::Status s = stream->ReadNBytes(bytes, &result);
if (!s.ok() && s.code() != tensorflow::error::OUT_OF_RANGE) {
Set_TF_Status_from_Status(status, s);
result.clear();
}
return result;
@ -250,35 +245,35 @@ string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
%newobject CreateWritableFile;
// Wrap the above functions.
inline void FileExists(const string& filename, TF_Status* out_status);
inline void DeleteFile(const string& filename, TF_Status* out_status);
string ReadFileToString(const string& filename, TF_Status* out_status);
inline void FileExists(const string& filename, TF_Status* status);
inline void DeleteFile(const string& filename, TF_Status* status);
string ReadFileToString(const string& filename, TF_Status* status);
void WriteStringToFile(const string& filename, const string& file_content,
TF_Status* out_status);
std::vector<string> GetChildren(const string& dir, TF_Status* out_status);
TF_Status* status);
std::vector<string> GetChildren(const string& dir, TF_Status* status);
std::vector<string> GetMatchingFiles(const string& filename,
TF_Status* out_status);
void CreateDir(const string& dirname, TF_Status* out_status);
void RecursivelyCreateDir(const string& dirname, TF_Status* out_status);
TF_Status* status);
void CreateDir(const string& dirname, TF_Status* status);
void RecursivelyCreateDir(const string& dirname, TF_Status* status);
void CopyFile(const string& oldpath, const string& newpath, bool overwrite,
TF_Status* out_status);
TF_Status* status);
void RenameFile(const string& oldname, const string& newname, bool overwrite,
TF_Status* out_status);
void DeleteRecursively(const string& dirname, TF_Status* out_status);
TF_Status* status);
void DeleteRecursively(const string& dirname, TF_Status* status);
bool IsDirectory(const string& dirname, TF_Status* out_status);
void Stat(const string& filename, tensorflow::FileStatistics* stats,
TF_Status* out_status);
TF_Status* status);
tensorflow::io::BufferedInputStream* CreateBufferedInputStream(
const string& filename, size_t buffer_size, TF_Status* out_status);
const string& filename, size_t buffer_size, TF_Status* status);
tensorflow::WritableFile* CreateWritableFile(const string& filename,
const string& mode,
TF_Status* out_status);
TF_Status* status);
void AppendToFile(const string& file_content, tensorflow::WritableFile* file,
TF_Status* out_status);
int64 TellFile(tensorflow::WritableFile* file, TF_Status* out_status);
TF_Status* status);
int64 TellFile(tensorflow::WritableFile* file, TF_Status* status);
string ReadFromStream(tensorflow::io::BufferedInputStream* stream,
size_t bytes,
TF_Status* out_status);
TF_Status* status);
%ignore tensorflow::Status::operator=;
%include "tensorflow/core/lib/core/status.h"

View File

@ -80,18 +80,16 @@ class FileIO(object):
if not self._read_check_passed:
raise errors.PermissionDeniedError(None, None,
"File isn't open for reading")
with errors.raise_exception_on_not_ok_status() as status:
self._read_buf = pywrap_tensorflow.CreateBufferedInputStream(
compat.as_bytes(self.__name), 1024 * 512, status)
self._read_buf = pywrap_tensorflow.CreateBufferedInputStream(
compat.as_bytes(self.__name), 1024 * 512)
def _prewrite_check(self):
if not self._writable_file:
if not self._write_check_passed:
raise errors.PermissionDeniedError(None, None,
"File isn't open for writing")
with errors.raise_exception_on_not_ok_status() as status:
self._writable_file = pywrap_tensorflow.CreateWritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode), status)
self._writable_file = pywrap_tensorflow.CreateWritableFile(
compat.as_bytes(self.__name), compat.as_bytes(self.__mode))
def _prepare_value(self, val):
if self._binary_mode:
@ -106,9 +104,8 @@ class FileIO(object):
def write(self, file_content):
"""Writes file_content to the file. Appends to the end of the file."""
self._prewrite_check()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.AppendToFile(
compat.as_bytes(file_content), self._writable_file, status)
pywrap_tensorflow.AppendToFile(
compat.as_bytes(file_content), self._writable_file)
def read(self, n=-1):
"""Returns the contents of a file as a string.
@ -123,13 +120,12 @@ class FileIO(object):
string if in string (regular) mode.
"""
self._preread_check()
with errors.raise_exception_on_not_ok_status() as status:
if n == -1:
length = self.size() - self.tell()
else:
length = n
return self._prepare_value(
pywrap_tensorflow.ReadFromStream(self._read_buf, length, status))
if n == -1:
length = self.size() - self.tell()
else:
length = n
return self._prepare_value(
pywrap_tensorflow.ReadFromStream(self._read_buf, length))
@deprecation.deprecated_args(
None,
@ -202,8 +198,7 @@ class FileIO(object):
else:
self._prewrite_check()
with errors.raise_exception_on_not_ok_status() as status:
return pywrap_tensorflow.TellFile(self._writable_file, status)
return pywrap_tensorflow.TellFile(self._writable_file)
def __enter__(self):
"""Make usable with "with" statement."""
@ -284,8 +279,7 @@ def file_exists_v2(path):
errors.OpError: Propagates any errors reported by the FileSystem API.
"""
try:
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.FileExists(compat.as_bytes(path), status)
pywrap_tensorflow.FileExists(compat.as_bytes(path))
except errors.NotFoundError:
return False
return True
@ -316,8 +310,7 @@ def delete_file_v2(path):
errors.OpError: Propagates any errors reported by the FileSystem API. E.g.,
NotFoundError if the path does not exist.
"""
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.DeleteFile(compat.as_bytes(path), status)
pywrap_tensorflow.DeleteFile(compat.as_bytes(path))
def read_file_to_string(filename, binary_mode=False):
@ -385,22 +378,21 @@ def get_matching_files_v2(pattern):
Raises:
errors.OpError: If there are filesystem / directory listing errors.
"""
with errors.raise_exception_on_not_ok_status() as status:
if isinstance(pattern, six.string_types):
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
compat.as_bytes(pattern), status)
]
else:
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
for single_filename in pattern
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
compat.as_bytes(single_filename), status)
]
if isinstance(pattern, six.string_types):
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
compat.as_bytes(pattern))
]
else:
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename) # pylint: disable=g-complex-comprehension
for single_filename in pattern
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
compat.as_bytes(single_filename))
]
@tf_export(v1=["gfile.MkDir"])
@ -434,8 +426,7 @@ def create_dir_v2(path):
Raises:
errors.OpError: If the operation fails.
"""
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.CreateDir(compat.as_bytes(path), status)
pywrap_tensorflow.CreateDir(compat.as_bytes(path))
@tf_export(v1=["gfile.MakeDirs"])
@ -465,8 +456,7 @@ def recursive_create_dir_v2(path):
Raises:
errors.OpError: If the operation fails.
"""
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(path), status)
pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(path))
@tf_export(v1=["gfile.Copy"])
@ -498,9 +488,8 @@ def copy_v2(src, dst, overwrite=False):
Raises:
errors.OpError: If the operation fails.
"""
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.CopyFile(
compat.as_bytes(src), compat.as_bytes(dst), overwrite, status)
pywrap_tensorflow.CopyFile(
compat.as_bytes(src), compat.as_bytes(dst), overwrite)
@tf_export(v1=["gfile.Rename"])
@ -532,9 +521,8 @@ def rename_v2(src, dst, overwrite=False):
Raises:
errors.OpError: If the operation fails.
"""
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.RenameFile(
compat.as_bytes(src), compat.as_bytes(dst), overwrite, status)
pywrap_tensorflow.RenameFile(
compat.as_bytes(src), compat.as_bytes(dst), overwrite)
def atomic_write_string_to_file(filename, contents, overwrite=True):
@ -584,8 +572,7 @@ def delete_recursively_v2(path):
Raises:
errors.OpError: If the operation fails.
"""
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.DeleteRecursively(compat.as_bytes(path), status)
pywrap_tensorflow.DeleteRecursively(compat.as_bytes(path))
@tf_export(v1=["gfile.IsDirectory"])
@ -655,14 +642,13 @@ def list_directory_v2(path):
node_def=None,
op=None,
message="Could not find directory {}".format(path))
with errors.raise_exception_on_not_ok_status() as status:
# Convert each element to string, since the return values of the
# vector of string should be interpreted as strings, not bytes.
return [
compat.as_str_any(filename)
for filename in pywrap_tensorflow.GetChildren(
compat.as_bytes(path), status)
]
# Convert each element to string, since the return values of the
# vector of string should be interpreted as strings, not bytes.
return [
compat.as_str_any(filename)
for filename in pywrap_tensorflow.GetChildren(compat.as_bytes(path))
]
@tf_export(v1=["gfile.Walk"])
@ -763,9 +749,8 @@ def stat_v2(path):
errors.OpError: If the operation fails.
"""
file_statistics = pywrap_tensorflow.FileStatistics()
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.Stat(compat.as_bytes(path), file_statistics, status)
return file_statistics
pywrap_tensorflow.Stat(compat.as_bytes(path), file_statistics)
return file_statistics
def filecmp(filename_a, filename_b):

View File

@ -22,19 +22,19 @@ limitations under the License.
static PyObject* DoQuantizeTrainingOnGraphDefHelper(
const string& input_graph,
int num_bits,
TF_Status* out_status) {
TF_Status* status) {
string result;
// TODO(suharshs): Make the QuantizeAndDequantizeV2 configurable.
tensorflow::Status status =
tensorflow::Status s =
tensorflow::DoQuantizeTrainingOnSerializedGraphDef(input_graph, num_bits,
"QuantizeAndDequantizeV2", &result);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
Py_RETURN_NONE;
}
PyObject* py_str = PyBytes_FromStringAndSize(result.data(), result.size());
if (!py_str) {
Set_TF_Status_from_Status(out_status,
Set_TF_Status_from_Status(status,
tensorflow::Status(tensorflow::error::INTERNAL,
"Failed to generate serialized string of the rewritten graph."));
Py_RETURN_NONE;
@ -51,7 +51,7 @@ static PyObject* DoQuantizeTrainingOnGraphDefHelper(
PyObject* DoQuantizeTrainingOnGraphDefHelper(
const string& input_graph,
int num_bits,
TF_Status* out_status);
TF_Status* status);
%insert("python") %{
@ -70,10 +70,10 @@ def do_quantize_training_on_graphdef(input_graph, num_bits):
"""
from tensorflow.core.framework.graph_pb2 import GraphDef
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
graph = GraphDef()
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
input_graph.SerializeToString(), num_bits, status)
graph = GraphDef()
result_graph_string = DoQuantizeTrainingOnGraphDefHelper(
input_graph.SerializeToString(), num_bits)
graph.ParseFromString(result_graph_string)
return graph

View File

@ -104,15 +104,15 @@ limitations under the License.
static PyObject* CheckpointReader_GetTensor(
tensorflow::checkpoint::CheckpointReader* reader,
const string& name,
TF_Status* out_status) {
TF_Status* status) {
PyObject* py_obj = Py_None;
std::unique_ptr<tensorflow::Tensor> tensor;
reader->GetTensor(name, &tensor, out_status);
if (TF_GetCode(out_status) == TF_OK) {
tensorflow::Status status =
reader->GetTensor(name, &tensor, status);
if (TF_GetCode(status) == TF_OK) {
tensorflow::Status s =
tensorflow::ConvertTensorToNdarray(*tensor.get(), &py_obj);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
}
}
return py_obj;
@ -123,7 +123,7 @@ static PyObject* CheckpointReader_GetTensor(
PyObject* CheckpointReader_GetTensor(
tensorflow::checkpoint::CheckpointReader* reader,
const string& name,
TF_Status* out_status);
TF_Status* status);
%ignoreall
@ -150,20 +150,16 @@ PyObject* CheckpointReader_GetTensor(
return self._HasTensor(compat.as_bytes(tensor_str))
def get_tensor(self, tensor_str):
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
from tensorflow.python.util import compat
return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str),
status)
from tensorflow.python.util import compat
return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
%}
}
%insert("python") %{
def NewCheckpointReader(filepattern):
from tensorflow.python.framework import errors
with errors.raise_exception_on_not_ok_status() as status:
from tensorflow.python.util import compat
return CheckpointReader(compat.as_bytes(filepattern), status)
from tensorflow.python.util import compat
return CheckpointReader(compat.as_bytes(filepattern))
NewCheckpointReader._tf_api_names_v1 = ['train.NewCheckpointReader']
%}