diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index e3d77b370bb..538e3bbc9eb 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -119,4 +119,28 @@ class ReadFileOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp); +class WriteFileOp : public OpKernel { + public: + using OpKernel::OpKernel; + void Compute(OpKernelContext* context) override { + const Tensor* filename_input; + const Tensor* contents_input; + OP_REQUIRES_OK(context, context->input("filename", &filename_input)); + OP_REQUIRES_OK(context, context->input("contents", &contents_input)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()), + errors::InvalidArgument( + "Input filename tensor must be scalar, but had shape: ", + filename_input->shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()), + errors::InvalidArgument( + "Contents tensor must be scalar, but had shape: ", + contents_input->shape().DebugString())); + OP_REQUIRES_OK( + context, + WriteStringToFile(context->env(), filename_input->scalar()(), + contents_input->scalar()())); + } +}; + +REGISTER_KERNEL_BUILDER(Name("WriteFile").Device(DEVICE_CPU), WriteFileOp); } // namespace tensorflow diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 83f0542e02d..1167461e9e5 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -582,6 +582,22 @@ REGISTER_OP("ReadFile") Reads and outputs the entire contents of the input filename. )doc"); +REGISTER_OP("WriteFile") + .Input("filename: string") + .Input("contents: string") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Writes contents to the file at input filename. Creates file if not existing. + +filename: scalar. The name of the file to which we write the contents. +contents: scalar. The content to be written to the output file. +)doc"); + REGISTER_OP("MatchingFiles") .Input("pattern: string") .Output("filenames: string") diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index 9e85fe2b97a..d484a609fce 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -39,6 +39,18 @@ class IoOpsTest(tf.test.TestCase): self.assertEqual([], read.get_shape()) self.assertEqual(read.eval(), contents) + def testWriteFile(self): + cases = ['', 'Some contents'] + for contents in cases: + contents = tf.compat.as_bytes(contents) + temp = tempfile.NamedTemporaryFile( + prefix='WriteFileTest', dir=self.get_temp_dir()) + with self.test_session() as sess: + w = tf.write_file(temp.name, contents) + sess.run(w) + file_contents = open(temp.name, 'rb').read() + self.assertEqual(file_contents, contents) + def _subset(self, files, indices): return set(tf.compat.as_bytes(files[i].name) for i in range(len(files)) if i in indices) diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 15605ee42a4..7daf7c8cc82 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -92,6 +92,7 @@ Queues](../../how_tos/threading_and_queues/index.md). @@matching_files @@read_file +@@write_file ## Input pipeline @@ -521,4 +522,5 @@ ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("WriteFile")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)