Implementing WriteFile Op (#5330)

* Implementing WriteFile Op

* Addressing comments. Output value 0 means success, -1 means failure

* minor

* Changing the write_file op into a void op with no-return value

* Using OP_REQUIRES_OK(WriteStringToFile... instead of TF_CHECK_OK(...

* Code style checks using clang-format
This commit is contained in:
Moustafa Alzantot 2016-11-03 09:44:27 -07:00 committed by Vijay Vasudevan
parent 7b7c02de56
commit bf99cf2fdd
4 changed files with 54 additions and 0 deletions

View File

@ -119,4 +119,28 @@ class ReadFileOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp); REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
class WriteFileOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* context) override {
const Tensor* filename_input;
const Tensor* contents_input;
OP_REQUIRES_OK(context, context->input("filename", &filename_input));
OP_REQUIRES_OK(context, context->input("contents", &contents_input));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()),
errors::InvalidArgument(
"Input filename tensor must be scalar, but had shape: ",
filename_input->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()),
errors::InvalidArgument(
"Contents tensor must be scalar, but had shape: ",
contents_input->shape().DebugString()));
OP_REQUIRES_OK(
context,
WriteStringToFile(context->env(), filename_input->scalar<string>()(),
contents_input->scalar<string>()()));
}
};
REGISTER_KERNEL_BUILDER(Name("WriteFile").Device(DEVICE_CPU), WriteFileOp);
} // namespace tensorflow } // namespace tensorflow

View File

@ -582,6 +582,22 @@ REGISTER_OP("ReadFile")
Reads and outputs the entire contents of the input filename. Reads and outputs the entire contents of the input filename.
)doc"); )doc");
REGISTER_OP("WriteFile")
.Input("filename: string")
.Input("contents: string")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
return Status::OK();
})
.Doc(R"doc(
Writes contents to the file at input filename. Creates file if not existing.
filename: scalar. The name of the file to which we write the contents.
contents: scalar. The content to be written to the output file.
)doc");
REGISTER_OP("MatchingFiles") REGISTER_OP("MatchingFiles")
.Input("pattern: string") .Input("pattern: string")
.Output("filenames: string") .Output("filenames: string")

View File

@ -39,6 +39,18 @@ class IoOpsTest(tf.test.TestCase):
self.assertEqual([], read.get_shape()) self.assertEqual([], read.get_shape())
self.assertEqual(read.eval(), contents) self.assertEqual(read.eval(), contents)
def testWriteFile(self):
cases = ['', 'Some contents']
for contents in cases:
contents = tf.compat.as_bytes(contents)
temp = tempfile.NamedTemporaryFile(
prefix='WriteFileTest', dir=self.get_temp_dir())
with self.test_session() as sess:
w = tf.write_file(temp.name, contents)
sess.run(w)
file_contents = open(temp.name, 'rb').read()
self.assertEqual(file_contents, contents)
def _subset(self, files, indices): def _subset(self, files, indices):
return set(tf.compat.as_bytes(files[i].name) return set(tf.compat.as_bytes(files[i].name)
for i in range(len(files)) if i in indices) for i in range(len(files)) if i in indices)

View File

@ -92,6 +92,7 @@ Queues](../../how_tos/threading_and_queues/index.md).
@@matching_files @@matching_files
@@read_file @@read_file
@@write_file
## Input pipeline ## Input pipeline
@ -521,4 +522,5 @@ ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("WriteFile")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)