Implementing WriteFile Op (#5330)
* Implementing WriteFile Op * Addressing comments. Output value 0 means success, -1 means failure * minor * Changing the write_file op into a void op with no-return value * Using OP_REQUIRES_OK(WriteStringToFile... instead of TF_CHECK_OK(... * Code style checks using clang-format
This commit is contained in:
parent
7b7c02de56
commit
bf99cf2fdd
@ -119,4 +119,28 @@ class ReadFileOp : public OpKernel {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
|
||||
|
||||
class WriteFileOp : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor* filename_input;
|
||||
const Tensor* contents_input;
|
||||
OP_REQUIRES_OK(context, context->input("filename", &filename_input));
|
||||
OP_REQUIRES_OK(context, context->input("contents", &contents_input));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()),
|
||||
errors::InvalidArgument(
|
||||
"Input filename tensor must be scalar, but had shape: ",
|
||||
filename_input->shape().DebugString()));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()),
|
||||
errors::InvalidArgument(
|
||||
"Contents tensor must be scalar, but had shape: ",
|
||||
contents_input->shape().DebugString()));
|
||||
OP_REQUIRES_OK(
|
||||
context,
|
||||
WriteStringToFile(context->env(), filename_input->scalar<string>()(),
|
||||
contents_input->scalar<string>()()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("WriteFile").Device(DEVICE_CPU), WriteFileOp);
|
||||
} // namespace tensorflow
|
||||
|
@ -582,6 +582,22 @@ REGISTER_OP("ReadFile")
|
||||
Reads and outputs the entire contents of the input filename.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("WriteFile")
|
||||
.Input("filename: string")
|
||||
.Input("contents: string")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Writes contents to the file at input filename. Creates file if not existing.
|
||||
|
||||
filename: scalar. The name of the file to which we write the contents.
|
||||
contents: scalar. The content to be written to the output file.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MatchingFiles")
|
||||
.Input("pattern: string")
|
||||
.Output("filenames: string")
|
||||
|
@ -39,6 +39,18 @@ class IoOpsTest(tf.test.TestCase):
|
||||
self.assertEqual([], read.get_shape())
|
||||
self.assertEqual(read.eval(), contents)
|
||||
|
||||
def testWriteFile(self):
|
||||
cases = ['', 'Some contents']
|
||||
for contents in cases:
|
||||
contents = tf.compat.as_bytes(contents)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
prefix='WriteFileTest', dir=self.get_temp_dir())
|
||||
with self.test_session() as sess:
|
||||
w = tf.write_file(temp.name, contents)
|
||||
sess.run(w)
|
||||
file_contents = open(temp.name, 'rb').read()
|
||||
self.assertEqual(file_contents, contents)
|
||||
|
||||
def _subset(self, files, indices):
|
||||
return set(tf.compat.as_bytes(files[i].name)
|
||||
for i in range(len(files)) if i in indices)
|
||||
|
@ -92,6 +92,7 @@ Queues](../../how_tos/threading_and_queues/index.md).
|
||||
|
||||
@@matching_files
|
||||
@@read_file
|
||||
@@write_file
|
||||
|
||||
## Input pipeline
|
||||
|
||||
@ -521,4 +522,5 @@ ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("WriteFile")(common_shapes.call_cpp_shape_fn)
|
||||
ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)
|
||||
|
Loading…
Reference in New Issue
Block a user