Implementing WriteFile Op (#5330)
* Implementing WriteFile Op * Addressing comments. Output value 0 means success, -1 means failure * minor * Changing the write_file op into a void op with no-return value * Using OP_REQUIRES_OK(WriteStringToFile... instead of TF_CHECK_OK(... * Code style checks using clang-format
This commit is contained in:
parent
7b7c02de56
commit
bf99cf2fdd
@ -119,4 +119,28 @@ class ReadFileOp : public OpKernel {
|
|||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
|
REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
|
||||||
|
|
||||||
|
class WriteFileOp : public OpKernel {
|
||||||
|
public:
|
||||||
|
using OpKernel::OpKernel;
|
||||||
|
void Compute(OpKernelContext* context) override {
|
||||||
|
const Tensor* filename_input;
|
||||||
|
const Tensor* contents_input;
|
||||||
|
OP_REQUIRES_OK(context, context->input("filename", &filename_input));
|
||||||
|
OP_REQUIRES_OK(context, context->input("contents", &contents_input));
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input filename tensor must be scalar, but had shape: ",
|
||||||
|
filename_input->shape().DebugString()));
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Contents tensor must be scalar, but had shape: ",
|
||||||
|
contents_input->shape().DebugString()));
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
context,
|
||||||
|
WriteStringToFile(context->env(), filename_input->scalar<string>()(),
|
||||||
|
contents_input->scalar<string>()()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("WriteFile").Device(DEVICE_CPU), WriteFileOp);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -582,6 +582,22 @@ REGISTER_OP("ReadFile")
|
|||||||
Reads and outputs the entire contents of the input filename.
|
Reads and outputs the entire contents of the input filename.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_OP("WriteFile")
|
||||||
|
.Input("filename: string")
|
||||||
|
.Input("contents: string")
|
||||||
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
ShapeHandle unused;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"doc(
|
||||||
|
Writes contents to the file at input filename. Creates file if not existing.
|
||||||
|
|
||||||
|
filename: scalar. The name of the file to which we write the contents.
|
||||||
|
contents: scalar. The content to be written to the output file.
|
||||||
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("MatchingFiles")
|
REGISTER_OP("MatchingFiles")
|
||||||
.Input("pattern: string")
|
.Input("pattern: string")
|
||||||
.Output("filenames: string")
|
.Output("filenames: string")
|
||||||
|
@ -39,6 +39,18 @@ class IoOpsTest(tf.test.TestCase):
|
|||||||
self.assertEqual([], read.get_shape())
|
self.assertEqual([], read.get_shape())
|
||||||
self.assertEqual(read.eval(), contents)
|
self.assertEqual(read.eval(), contents)
|
||||||
|
|
||||||
|
def testWriteFile(self):
|
||||||
|
cases = ['', 'Some contents']
|
||||||
|
for contents in cases:
|
||||||
|
contents = tf.compat.as_bytes(contents)
|
||||||
|
temp = tempfile.NamedTemporaryFile(
|
||||||
|
prefix='WriteFileTest', dir=self.get_temp_dir())
|
||||||
|
with self.test_session() as sess:
|
||||||
|
w = tf.write_file(temp.name, contents)
|
||||||
|
sess.run(w)
|
||||||
|
file_contents = open(temp.name, 'rb').read()
|
||||||
|
self.assertEqual(file_contents, contents)
|
||||||
|
|
||||||
def _subset(self, files, indices):
|
def _subset(self, files, indices):
|
||||||
return set(tf.compat.as_bytes(files[i].name)
|
return set(tf.compat.as_bytes(files[i].name)
|
||||||
for i in range(len(files)) if i in indices)
|
for i in range(len(files)) if i in indices)
|
||||||
|
@ -92,6 +92,7 @@ Queues](../../how_tos/threading_and_queues/index.md).
|
|||||||
|
|
||||||
@@matching_files
|
@@matching_files
|
||||||
@@read_file
|
@@read_file
|
||||||
|
@@write_file
|
||||||
|
|
||||||
## Input pipeline
|
## Input pipeline
|
||||||
|
|
||||||
@ -521,4 +522,5 @@ ops.RegisterShape("ReaderReadUpTo")(common_shapes.call_cpp_shape_fn)
|
|||||||
ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("ReaderReset")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("ReaderRestoreState")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("ReadFile")(common_shapes.call_cpp_shape_fn)
|
||||||
|
ops.RegisterShape("WriteFile")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("MatchingFiles")(common_shapes.call_cpp_shape_fn)
|
||||||
|
Loading…
Reference in New Issue
Block a user