Converted py_record_writer.i and py_record_reader.i to pybind11
This is part of a larger effort to deprecate swig and eventually with modularization break pywrap_tensorflow into smaller components. Please refer to https://github.com/tensorflow/community/blob/master/rfcs/20190208-pybind11.md for more information. PiperOrigin-RevId: 286188001 Change-Id: Iebdc4335de88bc9e42267e68e28e0d7a9e840b9f
This commit is contained in:
parent
77cb370373
commit
54f1de4fcf
@ -853,18 +853,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "py_record_writer_lib",
|
||||
srcs = ["lib/io/py_record_writer.cc"],
|
||||
hdrs = ["lib/io/py_record_writer.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
name = "framework/test_file_system.so",
|
||||
srcs = ["framework/test_file_system.cc"],
|
||||
@ -5492,7 +5480,6 @@ tf_py_wrap_cc(
|
||||
"grappler/tf_optimizer.i",
|
||||
"lib/core/strings.i",
|
||||
"lib/io/py_record_reader.i",
|
||||
"lib/io/py_record_writer.i",
|
||||
"platform/base.i",
|
||||
"//tensorflow/compiler/mlir/python:mlir.i",
|
||||
],
|
||||
@ -5512,7 +5499,6 @@ tf_py_wrap_cc(
|
||||
":py_exception_registry",
|
||||
":py_func_lib",
|
||||
":py_record_reader_lib",
|
||||
":py_record_writer_lib",
|
||||
":python_op_gen",
|
||||
":tf_session_helper",
|
||||
"//third_party/python_runtime:headers",
|
||||
@ -5702,6 +5688,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":_pywrap_file_io",
|
||||
":_pywrap_record_io",
|
||||
":errors",
|
||||
":pywrap_tensorflow",
|
||||
":util",
|
||||
@ -5709,6 +5696,21 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_record_io",
|
||||
srcs = ["lib/io/record_io_wrapper.cc"],
|
||||
module_name = "_pywrap_record_io",
|
||||
deps = [
|
||||
":pybind11_absl",
|
||||
":pybind11_status",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:types",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "session",
|
||||
srcs = ["client/session.py"],
|
||||
|
@ -22,9 +22,10 @@ import glob
|
||||
import os
|
||||
import threading
|
||||
|
||||
from six.moves import map
|
||||
|
||||
from tensorflow.core.protobuf import debug_event_pb2
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
|
||||
@ -69,19 +70,9 @@ class DebugEventsReader(object):
|
||||
if file_path not in self._readers: # 1st check, without lock.
|
||||
with self._readers_lock:
|
||||
if file_path not in self._readers: # 2nd check, with lock.
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._readers[file_path] = pywrap_tensorflow.PyRecordReader_New(
|
||||
compat.as_bytes(file_path), 0, b"", status)
|
||||
reader = self._readers[file_path]
|
||||
while True:
|
||||
try:
|
||||
reader.GetNext()
|
||||
except (errors.DataLossError, errors.OutOfRangeError):
|
||||
# We ignore partial read exceptions, because a record may be truncated.
|
||||
# PyRecordReader holds the offset prior to the failed read, so retrying
|
||||
# will succeed.
|
||||
break
|
||||
yield debug_event_pb2.DebugEvent.FromString(reader.record())
|
||||
self._readers[file_path] = tf_record.tf_record_iterator(file_path)
|
||||
|
||||
return map(debug_event_pb2.DebugEvent.FromString, self._readers[file_path])
|
||||
|
||||
def metadata_iterator(self):
|
||||
return self._generic_iterator(self._metadata_path)
|
||||
@ -102,5 +93,5 @@ class DebugEventsReader(object):
|
||||
return self._generic_iterator(self._graph_execution_traces_path)
|
||||
|
||||
def close(self):
|
||||
for reader in self._readers.values():
|
||||
reader.Close()
|
||||
with self._readers_lock:
|
||||
self._readers.clear()
|
||||
|
@ -1,102 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/lib/io/py_record_writer.h"
|
||||
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace io {
|
||||
|
||||
PyRecordWriter::PyRecordWriter() {}
|
||||
|
||||
PyRecordWriter* PyRecordWriter::New(const string& filename,
|
||||
const io::RecordWriterOptions& options,
|
||||
TF_Status* out_status) {
|
||||
std::unique_ptr<WritableFile> file;
|
||||
Status s = Env::Default()->NewWritableFile(filename, &file);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return nullptr;
|
||||
}
|
||||
PyRecordWriter* writer = new PyRecordWriter;
|
||||
writer->file_ = std::move(file);
|
||||
writer->writer_.reset(new RecordWriter(writer->file_.get(), options));
|
||||
return writer;
|
||||
}
|
||||
|
||||
PyRecordWriter::~PyRecordWriter() {
|
||||
// Writer depends on file during close for zlib flush, so destruct first.
|
||||
writer_.reset();
|
||||
file_.reset();
|
||||
}
|
||||
|
||||
void PyRecordWriter::WriteRecord(tensorflow::StringPiece record,
|
||||
TF_Status* out_status) {
|
||||
if (writer_ == nullptr) {
|
||||
TF_SetStatus(out_status, TF_FAILED_PRECONDITION,
|
||||
"Writer not initialized or previously closed");
|
||||
return;
|
||||
}
|
||||
Status s = writer_->WriteRecord(record);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
}
|
||||
}
|
||||
|
||||
void PyRecordWriter::Flush(TF_Status* out_status) {
|
||||
if (writer_ == nullptr) {
|
||||
TF_SetStatus(out_status, TF_FAILED_PRECONDITION,
|
||||
"Writer not initialized or previously closed");
|
||||
return;
|
||||
}
|
||||
Status s = writer_->Flush();
|
||||
if (s.ok()) {
|
||||
// Per the RecordWriter contract, flushing the RecordWriter does not
|
||||
// flush the underlying file. Here we need to do both.
|
||||
s = file_->Flush();
|
||||
}
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void PyRecordWriter::Close(TF_Status* out_status) {
|
||||
if (writer_ != nullptr) {
|
||||
Status s = writer_->Close();
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
}
|
||||
writer_.reset(nullptr);
|
||||
}
|
||||
if (file_ != nullptr) {
|
||||
Status s = file_->Close();
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(out_status, s);
|
||||
return;
|
||||
}
|
||||
file_.reset(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace io
|
||||
} // namespace tensorflow
|
@ -1,60 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_WRITER_H_
|
||||
#define TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_WRITER_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class WritableFile;
|
||||
|
||||
namespace io {
|
||||
|
||||
class RecordWriter;
|
||||
|
||||
// A wrapper around io::RecordWriter that is more easily SWIG wrapped for
|
||||
// Python. An instance of this class is not safe for concurrent access
|
||||
// by multiple threads.
|
||||
class PyRecordWriter {
|
||||
public:
|
||||
static PyRecordWriter* New(const string& filename,
|
||||
const io::RecordWriterOptions& compression_options,
|
||||
TF_Status* out_status);
|
||||
~PyRecordWriter();
|
||||
|
||||
void WriteRecord(tensorflow::StringPiece record, TF_Status* out_status);
|
||||
void Flush(TF_Status* out_status);
|
||||
void Close(TF_Status* out_status);
|
||||
|
||||
private:
|
||||
PyRecordWriter();
|
||||
|
||||
std::unique_ptr<io::RecordWriter> writer_;
|
||||
std::unique_ptr<WritableFile> file_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
|
||||
};
|
||||
|
||||
} // namespace io
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_LIB_IO_PY_RECORD_WRITER_H_
|
@ -1,76 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
%nothread tensorflow::io::PyRecordWriter::WriteRecord;
|
||||
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
%include "tensorflow/python/lib/core/strings.i"
|
||||
|
||||
// Define int8_t explicitly instead of including "stdint.i", since "stdint.h"
|
||||
// and "stdint.i" disagree on the definition of int64_t.
|
||||
typedef signed char int8;
|
||||
%{ typedef signed char int8; %}
|
||||
|
||||
%feature("except") tensorflow::io::PyRecordWriter::New {
|
||||
// Let other threads run while we write
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
$action
|
||||
Py_END_ALLOW_THREADS
|
||||
}
|
||||
|
||||
%newobject tensorflow::io::PyRecordWriter::New;
|
||||
%newobject tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
|
||||
|
||||
%feature("except") tensorflow::io::PyRecordWriter::WriteRecord {
|
||||
// Let other threads run while we write
|
||||
Py_BEGIN_ALLOW_THREADS
|
||||
$action
|
||||
Py_END_ALLOW_THREADS
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/python/lib/io/py_record_writer.h"
|
||||
%}
|
||||
|
||||
%ignoreall
|
||||
|
||||
%unignore tensorflow;
|
||||
%unignore tensorflow::io;
|
||||
%unignore tensorflow::io::PyRecordWriter;
|
||||
%unignore tensorflow::io::PyRecordWriter::~PyRecordWriter;
|
||||
%unignore tensorflow::io::PyRecordWriter::WriteRecord;
|
||||
%unignore tensorflow::io::PyRecordWriter::Flush;
|
||||
%unignore tensorflow::io::PyRecordWriter::Close;
|
||||
%unignore tensorflow::io::PyRecordWriter::New;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::flush_mode;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::input_buffer_size;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::output_buffer_size;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::window_bits;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::compression_level;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::compression_method;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::mem_level;
|
||||
%unignore tensorflow::io::ZlibCompressionOptions::compression_strategy;
|
||||
%unignore tensorflow::io::RecordWriterOptions;
|
||||
%unignore tensorflow::io::RecordWriterOptions::CreateRecordWriterOptions;
|
||||
%unignore tensorflow::io::RecordWriterOptions::zlib_options;
|
||||
|
||||
%include "tensorflow/core/lib/io/record_writer.h"
|
||||
%include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
%include "tensorflow/python/lib/io/py_record_writer.h"
|
||||
|
||||
%unignoreall
|
254
tensorflow/python/lib/io/record_io_wrapper.cc
Normal file
254
tensorflow/python/lib/io/record_io_wrapper.cc
Normal file
@ -0,0 +1,254 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/record_reader.h"
|
||||
#include "tensorflow/core/lib/io/record_writer.h"
|
||||
#include "tensorflow/core/lib/io/zlib_compression_options.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_absl.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace py = ::pybind11;
|
||||
|
||||
class PyRecordReader {
|
||||
public:
|
||||
// NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
|
||||
// RecordReaderOptions, if this changes the API can be updated at that time.
|
||||
static tensorflow::Status New(const std::string& filename,
|
||||
const std::string& compression_type,
|
||||
PyRecordReader** out) {
|
||||
std::unique_ptr<tensorflow::RandomAccessFile> file;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
|
||||
auto options =
|
||||
tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions(
|
||||
compression_type);
|
||||
options.buffer_size = kReaderBufferSize;
|
||||
auto reader =
|
||||
absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
|
||||
*out = new PyRecordReader(std::move(file), std::move(reader));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
PyRecordReader() = delete;
|
||||
~PyRecordReader() { Close(); }
|
||||
|
||||
tensorflow::Status ReadNextRecord(tensorflow::tstring* out) {
|
||||
if (IsClosed()) {
|
||||
return tensorflow::errors::FailedPrecondition("Reader is closed.");
|
||||
}
|
||||
|
||||
return reader_->ReadRecord(&offset_, out);
|
||||
}
|
||||
|
||||
bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
|
||||
|
||||
void Close() {
|
||||
reader_ = nullptr;
|
||||
file_ = nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
|
||||
|
||||
PyRecordReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
|
||||
std::unique_ptr<tensorflow::io::RecordReader> reader)
|
||||
: file_(std::move(file)), reader_(std::move(reader)) {
|
||||
offset_ = 0;
|
||||
}
|
||||
|
||||
tensorflow::uint64 offset_;
|
||||
std::unique_ptr<tensorflow::RandomAccessFile> file_;
|
||||
std::unique_ptr<tensorflow::io::RecordReader> reader_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
|
||||
};
|
||||
|
||||
class PyRecordWriter {
|
||||
public:
|
||||
static tensorflow::Status New(
|
||||
const std::string& filename,
|
||||
const tensorflow::io::RecordWriterOptions& options,
|
||||
PyRecordWriter** out) {
|
||||
std::unique_ptr<tensorflow::WritableFile> file;
|
||||
TF_RETURN_IF_ERROR(
|
||||
tensorflow::Env::Default()->NewWritableFile(filename, &file));
|
||||
auto writer =
|
||||
absl::make_unique<tensorflow::io::RecordWriter>(file.get(), options);
|
||||
*out = new PyRecordWriter(std::move(file), std::move(writer));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
PyRecordWriter() = delete;
|
||||
~PyRecordWriter() { Close(); }
|
||||
|
||||
tensorflow::Status WriteRecord(tensorflow::StringPiece record) {
|
||||
if (IsClosed()) {
|
||||
return tensorflow::errors::FailedPrecondition("Writer is closed.");
|
||||
}
|
||||
return writer_->WriteRecord(record);
|
||||
}
|
||||
|
||||
tensorflow::Status Flush() {
|
||||
if (IsClosed()) {
|
||||
return tensorflow::errors::FailedPrecondition("Writer is closed.");
|
||||
}
|
||||
|
||||
auto status = writer_->Flush();
|
||||
if (status.ok()) {
|
||||
// Per the RecordWriter contract, flushing the RecordWriter does not
|
||||
// flush the underlying file. Here we need to do both.
|
||||
return file_->Flush();
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
bool IsClosed() const { return file_ == nullptr && writer_ == nullptr; }
|
||||
|
||||
tensorflow::Status Close() {
|
||||
if (writer_ != nullptr) {
|
||||
auto status = writer_->Close();
|
||||
writer_ = nullptr;
|
||||
if (!status.ok()) return status;
|
||||
}
|
||||
if (file_ != nullptr) {
|
||||
auto status = file_->Close();
|
||||
file_ = nullptr;
|
||||
if (!status.ok()) return status;
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,
|
||||
std::unique_ptr<tensorflow::io::RecordWriter> writer)
|
||||
: file_(std::move(file)), writer_(std::move(writer)) {}
|
||||
|
||||
std::unique_ptr<tensorflow::WritableFile> file_;
|
||||
std::unique_ptr<tensorflow::io::RecordWriter> writer_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(_pywrap_record_io, m) {
|
||||
py::class_<PyRecordReader>(m, "RecordIterator")
|
||||
.def(py::init(
|
||||
[](const std::string& filename, const std::string& compression_type) {
|
||||
tensorflow::Status status;
|
||||
PyRecordReader* self = nullptr;
|
||||
{
|
||||
py::gil_scoped_release release;
|
||||
status = PyRecordReader::New(filename, compression_type, &self);
|
||||
}
|
||||
MaybeRaiseRegisteredFromStatus(status);
|
||||
return self;
|
||||
}))
|
||||
.def("__iter__", [](const py::object& self) { return self; })
|
||||
.def("__next__",
|
||||
[](PyRecordReader* self) {
|
||||
if (self->IsClosed()) {
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
|
||||
tensorflow::tstring record;
|
||||
tensorflow::Status status;
|
||||
{
|
||||
py::gil_scoped_release release;
|
||||
status = self->ReadNextRecord(&record);
|
||||
}
|
||||
if (tensorflow::errors::IsOutOfRange(status)) {
|
||||
// Don't close because the file being read could be updated
|
||||
// in-between
|
||||
// __next__ calls.
|
||||
throw py::stop_iteration();
|
||||
}
|
||||
MaybeRaiseRegisteredFromStatus(status);
|
||||
return py::bytes(record);
|
||||
})
|
||||
.def("close", [](PyRecordReader* self) { self->Close(); });
|
||||
|
||||
using tensorflow::io::ZlibCompressionOptions;
|
||||
py::class_<ZlibCompressionOptions>(m, "ZlibCompressionOptions")
|
||||
.def_readwrite("flush_mode", &ZlibCompressionOptions::flush_mode)
|
||||
.def_readwrite("input_buffer_size",
|
||||
&ZlibCompressionOptions::input_buffer_size)
|
||||
.def_readwrite("output_buffer_size",
|
||||
&ZlibCompressionOptions::output_buffer_size)
|
||||
.def_readwrite("window_bits", &ZlibCompressionOptions::window_bits)
|
||||
.def_readwrite("compression_level",
|
||||
&ZlibCompressionOptions::compression_level)
|
||||
.def_readwrite("compression_method",
|
||||
&ZlibCompressionOptions::compression_method)
|
||||
.def_readwrite("mem_level", &ZlibCompressionOptions::mem_level)
|
||||
.def_readwrite("compression_strategy",
|
||||
&ZlibCompressionOptions::compression_strategy);
|
||||
|
||||
using tensorflow::io::RecordWriterOptions;
|
||||
py::class_<RecordWriterOptions>(m, "RecordWriterOptions")
|
||||
.def(py::init(&RecordWriterOptions::CreateRecordWriterOptions))
|
||||
.def_readonly("compression_type", &RecordWriterOptions::compression_type)
|
||||
.def_readonly("zlib_options", &RecordWriterOptions::zlib_options);
|
||||
|
||||
using tensorflow::MaybeRaiseRegisteredFromStatus;
|
||||
|
||||
py::class_<PyRecordWriter>(m, "RecordWriter")
|
||||
.def(py::init(
|
||||
[](const std::string& filename, const RecordWriterOptions& options) {
|
||||
PyRecordWriter* self = nullptr;
|
||||
tensorflow::Status status;
|
||||
{
|
||||
py::gil_scoped_release release;
|
||||
status = PyRecordWriter::New(filename, options, &self);
|
||||
}
|
||||
MaybeRaiseRegisteredFromStatus(status);
|
||||
return self;
|
||||
}))
|
||||
.def("__enter__", [](const py::object& self) { return self; })
|
||||
.def("__exit__",
|
||||
[](PyRecordWriter* self, py::args) {
|
||||
MaybeRaiseRegisteredFromStatus(self->Close());
|
||||
})
|
||||
.def(
|
||||
"write",
|
||||
[](PyRecordWriter* self, tensorflow::StringPiece record) {
|
||||
tensorflow::Status status;
|
||||
{
|
||||
py::gil_scoped_release release;
|
||||
status = self->WriteRecord(record);
|
||||
}
|
||||
MaybeRaiseRegisteredFromStatus(status);
|
||||
},
|
||||
py::arg("record"))
|
||||
.def("flush",
|
||||
[](PyRecordWriter* self) {
|
||||
MaybeRaiseRegisteredFromStatus(self->Flush());
|
||||
})
|
||||
.def("close", [](PyRecordWriter* self) {
|
||||
MaybeRaiseRegisteredFromStatus(self->Close());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
@ -19,8 +19,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python import _pywrap_record_io
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -127,7 +126,7 @@ class TFRecordOptions(object):
|
||||
|
||||
def _as_record_writer_options(self):
|
||||
"""Convert to RecordWriterOptions for use with PyRecordWriter."""
|
||||
options = pywrap_tensorflow.RecordWriterOptions_CreateRecordWriterOptions(
|
||||
options = _pywrap_record_io.RecordWriterOptions(
|
||||
compat.as_bytes(
|
||||
self.get_compression_type_string(self.compression_type)))
|
||||
|
||||
@ -162,34 +161,20 @@ def tf_record_iterator(path, options=None):
|
||||
path: The path to the TFRecords file.
|
||||
options: (optional) A TFRecordOptions object.
|
||||
|
||||
Yields:
|
||||
Strings.
|
||||
Returns:
|
||||
An iterator of serialized TFRecords.
|
||||
|
||||
Raises:
|
||||
IOError: If `path` cannot be opened for reading.
|
||||
"""
|
||||
compression_type = TFRecordOptions.get_compression_type_string(options)
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
reader = pywrap_tensorflow.PyRecordReader_New(
|
||||
compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)
|
||||
|
||||
if reader is None:
|
||||
raise IOError("Could not open %s." % path)
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
reader.GetNext()
|
||||
except errors.OutOfRangeError:
|
||||
break
|
||||
yield reader.record()
|
||||
finally:
|
||||
reader.Close()
|
||||
return _pywrap_record_io.RecordIterator(path, compression_type)
|
||||
|
||||
|
||||
@tf_export(
|
||||
"io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
|
||||
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
|
||||
class TFRecordWriter(object):
|
||||
class TFRecordWriter(_pywrap_record_io.RecordWriter):
|
||||
"""A class to write records to a TFRecords file.
|
||||
|
||||
[TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord)
|
||||
@ -268,35 +253,29 @@ class TFRecordWriter(object):
|
||||
if not isinstance(options, TFRecordOptions):
|
||||
options = TFRecordOptions(compression_type=options)
|
||||
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
# pylint: disable=protected-access
|
||||
self._writer = pywrap_tensorflow.PyRecordWriter_New(
|
||||
compat.as_bytes(path), options._as_record_writer_options(), status)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter a `with` block."""
|
||||
return self
|
||||
|
||||
def __exit__(self, unused_type, unused_value, unused_traceback):
|
||||
"""Exit a `with` block, closing the file."""
|
||||
self.close()
|
||||
# pylint: disable=protected-access
|
||||
super(TFRecordWriter, self).__init__(
|
||||
compat.as_bytes(path), options._as_record_writer_options())
|
||||
# pylint: enable=protected-access
|
||||
|
||||
# TODO(slebedev): The following wrapper methods are there to compensate
|
||||
# for lack of signatures in pybind11-generated classes. Switch to
|
||||
# __text_signature__ when TensorFlow drops Python 2.X support.
|
||||
# See https://github.com/pybind/pybind11/issues/945
|
||||
# pylint: disable=useless-super-delegation
|
||||
def write(self, record):
|
||||
"""Write a string record to the file.
|
||||
|
||||
Args:
|
||||
record: str
|
||||
"""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._writer.WriteRecord(record, status)
|
||||
super(TFRecordWriter, self).write(record)
|
||||
|
||||
def flush(self):
|
||||
"""Flush the file."""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._writer.Flush(status)
|
||||
super(TFRecordWriter, self).flush()
|
||||
|
||||
def close(self):
|
||||
"""Close the file."""
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
self._writer.Close(status)
|
||||
super(TFRecordWriter, self).close()
|
||||
# pylint: enable=useless-super-delegation
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
%include "tensorflow/python/client/tf_session.i"
|
||||
|
||||
%include "tensorflow/python/lib/io/py_record_reader.i"
|
||||
%include "tensorflow/python/lib/io/py_record_writer.i"
|
||||
|
||||
%include "tensorflow/python/grappler/cluster.i"
|
||||
%include "tensorflow/python/grappler/item.i"
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.io.TFRecordWriter"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._pywrap_record_io.RecordWriter\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -42,7 +42,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "TFRecordWriter"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member {
|
||||
name: "VarLenFeature"
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.python_io.TFRecordWriter"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._pywrap_record_io.RecordWriter\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -10,7 +10,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "TFRecordWriter"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "tf_record_iterator"
|
||||
|
@ -1,7 +1,8 @@
|
||||
path: "tensorflow.io.TFRecordWriter"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
is_instance: "<class \'tensorflow.python._pywrap_record_io.RecordWriter\'>"
|
||||
is_instance: "<class \'pybind11_builtins.pybind11_object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -22,7 +22,7 @@ tf_module {
|
||||
}
|
||||
member {
|
||||
name: "TFRecordWriter"
|
||||
mtype: "<type \'type\'>"
|
||||
mtype: "<class \'pybind11_builtins.pybind11_type\'>"
|
||||
}
|
||||
member {
|
||||
name: "VarLenFeature"
|
||||
|
Loading…
Reference in New Issue
Block a user