Added experimental C APIs to build a stack of dataset + iterator nodes that
reads imagenet TFRecord files. PiperOrigin-RevId: 190488817
This commit is contained in:
parent
cc6b2ae837
commit
be917027e3
@ -220,6 +220,7 @@ tf_cc_test(
|
||||
name = "c_api_experimental_test",
|
||||
size = "small",
|
||||
srcs = ["c_api_experimental_test.cc"],
|
||||
data = ["testdata/tf_record"],
|
||||
linkopts = select({
|
||||
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
@ -230,6 +231,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":c_api_experimental",
|
||||
":c_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -87,25 +87,22 @@ TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
|
||||
TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
|
||||
size_t* len);
|
||||
|
||||
// Creates a stack of data set + iterator nodes reading the TFRecord files from
|
||||
// `file_path`, and outputs the following info on success:
|
||||
// Creates a stack of data set + iterator nodes, currently hard-coded to return
|
||||
// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success,
|
||||
// returns the IteratorGetNext node, which caller can run or feed into an node.
|
||||
//
|
||||
// 1. Returns the IteratorGetNext node, which caller can run or feed into an
|
||||
// node.
|
||||
//
|
||||
// 2. Sets `dataset_func` to the created function that encapsulates the data set
|
||||
// nodes. Caller owns that function, and must call TF_DeleteFunction() on it.
|
||||
//
|
||||
//
|
||||
// The nodes are currently hard-coded to return a single Int32 of value 1.
|
||||
// TODO(hongm): Extend the API to allow customization of the nodes created.
|
||||
TF_CAPI_EXPORT extern TF_Operation* TF_MakeIteratorGetNextWithDatasets(
|
||||
TF_Graph* graph, const char* file_path, TF_Function** dataset_func,
|
||||
TF_Status* status);
|
||||
TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(
|
||||
TF_Graph* graph, TF_Status* status);
|
||||
|
||||
// Returns the shape proto of shape {}.
|
||||
TF_CAPI_EXPORT extern void TF_GetAttrScalarTensorShapeProto(TF_Buffer* value,
|
||||
TF_Status* status);
|
||||
// Similar to the above API, except that the returned iterator reads the
|
||||
// TFRecord files from `file_path`.
|
||||
// The iterators outputs 2 tensors:
|
||||
// - A float tensor of shape `batch_size` X 224 X 224 X 3
|
||||
// - An int32 tensor of shape `batch_size`
|
||||
// TODO(hongm): Extend the API to allow customization of the nodes created.
|
||||
TF_CAPI_EXPORT extern TF_Operation* TF_MakeImagenetIteratorGetNextWithDatasets(
|
||||
TF_Graph* graph, const char* file_path, int batch_size, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
|
@ -15,38 +15,36 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/c/c_api_experimental.h"
|
||||
#include "tensorflow/c/c_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
void TestIteratorStack() {
|
||||
void TestFakeIteratorStack() {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
TF_Function* dataset_func = nullptr;
|
||||
|
||||
TF_Operation* get_next =
|
||||
TF_MakeIteratorGetNextWithDatasets(graph, "dummy_path", &dataset_func, s);
|
||||
TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
ASSERT_NE(dataset_func, nullptr);
|
||||
TF_DeleteFunction(dataset_func);
|
||||
|
||||
CSession csession(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Run the graph.
|
||||
for (int i = 0; i < 1; ++i) {
|
||||
const float base_value = 42.0;
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
csession.SetOutputs({get_next});
|
||||
csession.Run(s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_Tensor* out = csession.output_tensor(0);
|
||||
ASSERT_TRUE(out != nullptr);
|
||||
EXPECT_EQ(TF_INT32, TF_TensorType(out));
|
||||
EXPECT_EQ(0, TF_NumDims(out)); // scalar
|
||||
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
|
||||
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
|
||||
EXPECT_EQ(1, *output_contents);
|
||||
ASSERT_EQ(TF_FLOAT, TF_TensorType(out));
|
||||
ASSERT_EQ(0, TF_NumDims(out)); // scalar
|
||||
ASSERT_EQ(sizeof(float), TF_TensorByteSize(out));
|
||||
float* output_contents = static_cast<float*>(TF_TensorData(out));
|
||||
ASSERT_EQ(base_value + i, *output_contents);
|
||||
}
|
||||
|
||||
// This should error out since we've exhausted the iterator.
|
||||
@ -60,7 +58,63 @@ void TestIteratorStack() {
|
||||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
TEST(CAPI_EXPERIMENTAL, IteratorGetNext) { TestIteratorStack(); }
|
||||
TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); }
|
||||
|
||||
TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
|
||||
TF_Status* s = TF_NewStatus();
|
||||
TF_Graph* graph = TF_NewGraph();
|
||||
|
||||
const string file_path = tensorflow::io::JoinPath(
|
||||
tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record");
|
||||
VLOG(1) << "data file path is " << file_path;
|
||||
const int batch_size = 64;
|
||||
TF_Operation* get_next = TF_MakeImagenetIteratorGetNextWithDatasets(
|
||||
graph, file_path.c_str(), batch_size, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
CSession csession(graph, s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
// Run the graph.
|
||||
// The two output tensors should look like:
|
||||
// Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32)
|
||||
// Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32)
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
LOG(INFO) << "Running iter " << i;
|
||||
csession.SetOutputs({{get_next, 0}, {get_next, 1}});
|
||||
csession.Run(s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
|
||||
{
|
||||
TF_Tensor* image = csession.output_tensor(0);
|
||||
ASSERT_TRUE(image != nullptr);
|
||||
ASSERT_EQ(TF_FLOAT, TF_TensorType(image));
|
||||
// Confirm shape is 224 X 224 X 3
|
||||
ASSERT_EQ(4, TF_NumDims(image));
|
||||
ASSERT_EQ(batch_size, TF_Dim(image, 0));
|
||||
ASSERT_EQ(224, TF_Dim(image, 1));
|
||||
ASSERT_EQ(224, TF_Dim(image, 2));
|
||||
ASSERT_EQ(3, TF_Dim(image, 3));
|
||||
ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3,
|
||||
TF_TensorByteSize(image));
|
||||
}
|
||||
|
||||
{
|
||||
TF_Tensor* label = csession.output_tensor(1);
|
||||
ASSERT_TRUE(label != nullptr);
|
||||
ASSERT_EQ(TF_INT32, TF_TensorType(label));
|
||||
ASSERT_EQ(1, TF_NumDims(label));
|
||||
ASSERT_EQ(batch_size, TF_Dim(label, 0));
|
||||
ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label));
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up
|
||||
csession.CloseAndDelete(s);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||
TF_DeleteGraph(graph);
|
||||
TF_DeleteStatus(s);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
BIN
tensorflow/c/testdata/tf_record
vendored
Normal file
BIN
tensorflow/c/testdata/tf_record
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user