Added experimental C APIs to build a stack of dataset + iterator nodes that
reads imagenet TFRecord files. PiperOrigin-RevId: 190488817
This commit is contained in:
parent
cc6b2ae837
commit
be917027e3
@ -220,6 +220,7 @@ tf_cc_test(
|
|||||||
name = "c_api_experimental_test",
|
name = "c_api_experimental_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
srcs = ["c_api_experimental_test.cc"],
|
srcs = ["c_api_experimental_test.cc"],
|
||||||
|
data = ["testdata/tf_record"],
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
@ -230,6 +231,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":c_api_experimental",
|
":c_api_experimental",
|
||||||
":c_test_util",
|
":c_test_util",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
],
|
],
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -87,25 +87,22 @@ TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
|
|||||||
TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
|
TF_CAPI_EXPORT extern const char* TF_GraphDebugString(TF_Graph* graph,
|
||||||
size_t* len);
|
size_t* len);
|
||||||
|
|
||||||
// Creates a stack of data set + iterator nodes reading the TFRecord files from
|
// Creates a stack of data set + iterator nodes, currently hard-coded to return
|
||||||
// `file_path`, and outputs the following info on success:
|
// a sequence of 3 float values <42.0, 43.0, 44.0> over 3 calls. On success,
|
||||||
|
// returns the IteratorGetNext node, which caller can run or feed into an node.
|
||||||
//
|
//
|
||||||
// 1. Returns the IteratorGetNext node, which caller can run or feed into an
|
|
||||||
// node.
|
|
||||||
//
|
|
||||||
// 2. Sets `dataset_func` to the created function that encapsulates the data set
|
|
||||||
// nodes. Caller owns that function, and must call TF_DeleteFunction() on it.
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// The nodes are currently hard-coded to return a single Int32 of value 1.
|
|
||||||
// TODO(hongm): Extend the API to allow customization of the nodes created.
|
// TODO(hongm): Extend the API to allow customization of the nodes created.
|
||||||
TF_CAPI_EXPORT extern TF_Operation* TF_MakeIteratorGetNextWithDatasets(
|
TF_CAPI_EXPORT extern TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(
|
||||||
TF_Graph* graph, const char* file_path, TF_Function** dataset_func,
|
TF_Graph* graph, TF_Status* status);
|
||||||
TF_Status* status);
|
|
||||||
|
|
||||||
// Returns the shape proto of shape {}.
|
// Similar to the above API, except that the returned iterator reads the
|
||||||
TF_CAPI_EXPORT extern void TF_GetAttrScalarTensorShapeProto(TF_Buffer* value,
|
// TFRecord files from `file_path`.
|
||||||
TF_Status* status);
|
// The iterators outputs 2 tensors:
|
||||||
|
// - A float tensor of shape `batch_size` X 224 X 224 X 3
|
||||||
|
// - An int32 tensor of shape `batch_size`
|
||||||
|
// TODO(hongm): Extend the API to allow customization of the nodes created.
|
||||||
|
TF_CAPI_EXPORT extern TF_Operation* TF_MakeImagenetIteratorGetNextWithDatasets(
|
||||||
|
TF_Graph* graph, const char* file_path, int batch_size, TF_Status* status);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
} /* end extern "C" */
|
} /* end extern "C" */
|
||||||
|
@ -15,38 +15,36 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/c/c_api_experimental.h"
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
#include "tensorflow/c/c_test_util.h"
|
#include "tensorflow/c/c_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void TestIteratorStack() {
|
void TestFakeIteratorStack() {
|
||||||
TF_Status* s = TF_NewStatus();
|
TF_Status* s = TF_NewStatus();
|
||||||
TF_Graph* graph = TF_NewGraph();
|
TF_Graph* graph = TF_NewGraph();
|
||||||
|
|
||||||
TF_Function* dataset_func = nullptr;
|
TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s);
|
||||||
|
|
||||||
TF_Operation* get_next =
|
|
||||||
TF_MakeIteratorGetNextWithDatasets(graph, "dummy_path", &dataset_func, s);
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
ASSERT_NE(dataset_func, nullptr);
|
|
||||||
TF_DeleteFunction(dataset_func);
|
|
||||||
|
|
||||||
CSession csession(graph, s);
|
CSession csession(graph, s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
// Run the graph.
|
// Run the graph.
|
||||||
for (int i = 0; i < 1; ++i) {
|
const float base_value = 42.0;
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
csession.SetOutputs({get_next});
|
csession.SetOutputs({get_next});
|
||||||
csession.Run(s);
|
csession.Run(s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
TF_Tensor* out = csession.output_tensor(0);
|
TF_Tensor* out = csession.output_tensor(0);
|
||||||
ASSERT_TRUE(out != nullptr);
|
ASSERT_TRUE(out != nullptr);
|
||||||
EXPECT_EQ(TF_INT32, TF_TensorType(out));
|
ASSERT_EQ(TF_FLOAT, TF_TensorType(out));
|
||||||
EXPECT_EQ(0, TF_NumDims(out)); // scalar
|
ASSERT_EQ(0, TF_NumDims(out)); // scalar
|
||||||
ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
|
ASSERT_EQ(sizeof(float), TF_TensorByteSize(out));
|
||||||
int32* output_contents = static_cast<int32*>(TF_TensorData(out));
|
float* output_contents = static_cast<float*>(TF_TensorData(out));
|
||||||
EXPECT_EQ(1, *output_contents);
|
ASSERT_EQ(base_value + i, *output_contents);
|
||||||
}
|
}
|
||||||
|
|
||||||
// This should error out since we've exhausted the iterator.
|
// This should error out since we've exhausted the iterator.
|
||||||
@ -60,7 +58,63 @@ void TestIteratorStack() {
|
|||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CAPI_EXPERIMENTAL, IteratorGetNext) { TestIteratorStack(); }
|
TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); }
|
||||||
|
|
||||||
|
TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
|
||||||
|
TF_Status* s = TF_NewStatus();
|
||||||
|
TF_Graph* graph = TF_NewGraph();
|
||||||
|
|
||||||
|
const string file_path = tensorflow::io::JoinPath(
|
||||||
|
tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record");
|
||||||
|
VLOG(1) << "data file path is " << file_path;
|
||||||
|
const int batch_size = 64;
|
||||||
|
TF_Operation* get_next = TF_MakeImagenetIteratorGetNextWithDatasets(
|
||||||
|
graph, file_path.c_str(), batch_size, s);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
CSession csession(graph, s);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
// Run the graph.
|
||||||
|
// The two output tensors should look like:
|
||||||
|
// Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32)
|
||||||
|
// Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32)
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
LOG(INFO) << "Running iter " << i;
|
||||||
|
csession.SetOutputs({{get_next, 0}, {get_next, 1}});
|
||||||
|
csession.Run(s);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
|
{
|
||||||
|
TF_Tensor* image = csession.output_tensor(0);
|
||||||
|
ASSERT_TRUE(image != nullptr);
|
||||||
|
ASSERT_EQ(TF_FLOAT, TF_TensorType(image));
|
||||||
|
// Confirm shape is 224 X 224 X 3
|
||||||
|
ASSERT_EQ(4, TF_NumDims(image));
|
||||||
|
ASSERT_EQ(batch_size, TF_Dim(image, 0));
|
||||||
|
ASSERT_EQ(224, TF_Dim(image, 1));
|
||||||
|
ASSERT_EQ(224, TF_Dim(image, 2));
|
||||||
|
ASSERT_EQ(3, TF_Dim(image, 3));
|
||||||
|
ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3,
|
||||||
|
TF_TensorByteSize(image));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
TF_Tensor* label = csession.output_tensor(1);
|
||||||
|
ASSERT_TRUE(label != nullptr);
|
||||||
|
ASSERT_EQ(TF_INT32, TF_TensorType(label));
|
||||||
|
ASSERT_EQ(1, TF_NumDims(label));
|
||||||
|
ASSERT_EQ(batch_size, TF_Dim(label, 0));
|
||||||
|
ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
csession.CloseAndDelete(s);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
TF_DeleteGraph(graph);
|
||||||
|
TF_DeleteStatus(s);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
BIN
tensorflow/c/testdata/tf_record
vendored
Normal file
BIN
tensorflow/c/testdata/tf_record
vendored
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user