Support two data types as inputs in RemoteFusedGraphExecuteOp

Change: 152843285
A. Unique TensorFlower authored on 2017-04-11 11:12:15 -08:00; committed by TensorFlower Gardener
parent 740053be50
commit a11e669c5a
7 changed files with 82 additions and 54 deletions
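
At a glance, this commit changes the RemoteFusedGraphExecute op from inputs that all share a single data type (attrs "M"/"T", with outputs sharing "N"/"U") to inputs and outputs typed individually by the list attributes "Tinputs" and "Toutputs", so the inputs no longer have to agree on one dtype. As a rough sketch of the new calling convention (not taken from this commit; the function name and the upstream node names "float_in"/"int_in" are made up), a caller could build a NodeDef that mixes float and int32 inputs like this:

#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Sketch only: builds a RemoteFusedGraphExecute node whose two inputs carry
// different data types, which the new Tinputs/Toutputs attributes permit.
Status BuildMixedTypeExecuteNodeDef(NodeDef* node_def) {
  std::vector<NodeDefBuilder::NodeOut> src_list;
  src_list.emplace_back("float_in", 0, DT_FLOAT);
  src_list.emplace_back("int_in", 0, DT_INT32);
  const DataTypeVector input_types({DT_FLOAT, DT_INT32});
  const DataTypeVector output_types({DT_FLOAT});
  return NodeDefBuilder("remote_fused_graph_execute_op",
                        "RemoteFusedGraphExecute")
      .Input(src_list)                 // one list-typed input arg
      .Attr("Tinputs", input_types)    // per-input dtypes
      .Attr("Toutputs", output_types)  // per-output dtypes
      .Attr("serialized_remote_fused_graph_execute_info", "")
      .Finalize(node_def);
}

}  // namespace tensorflow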


@@ -50,6 +50,7 @@ tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc \
tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc \
tensorflow/core/kernels/remote_fused_graph_execute_op.cc \
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc \
tensorflow/core/ops/remote_fused_graph_ops.cc \
tensorflow/core/platform/posix/test.cc


@@ -107,18 +107,14 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
}
CHECK(status.ok());
const DataType input_data_type =
inputs.empty() ? DT_FLOAT : inputs.at(0).second.dtype();
Scope root = Scope::NewRootScope();
std::vector<Output> output_list;
DataTypeVector input_types;
for (const std::pair<string, Tensor>& input_node_info : inputs) {
const Scope& scope = root.WithOpName(input_node_info.first);
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("PlaceholderV2");
const DataType dt = input_node_info.second.dtype();
// DataType of input arguments should be same.
CHECK_EQ(input_data_type, dt);
auto builder = NodeBuilder(unique_name, "PlaceholderV2")
.Attr("dtype", input_node_info.second.dtype())
.Attr("shape", input_node_info.second.shape());
@@ -126,25 +122,21 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
CHECK(scope.ok());
output_list.emplace_back(Output(ret, 0));
input_types.push_back(input_node_info.second.dtype());
}
const RemoteFusedGraphExecuteInfo execute_info =
BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs,
tensor_shape_map);
const std::pair<DataType, TensorShape>* tensor_shape_type =
RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
outputs.at(0));
CHECK_NE(tensor_shape_type, nullptr);
const DataType output_data_type = tensor_shape_type->first;
DataTypeVector output_types;
// Sanity-check to confirm all output data types are same.
for (const string& output_node_name : outputs) {
const std::pair<DataType, TensorShape>* tst =
RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
output_node_name);
CHECK_NE(tst, nullptr);
const DataType dt = tensor_shape_type->first;
CHECK_EQ(output_data_type, dt);
output_types.push_back(tst->first);
}
const Scope& scope = root.WithOpName(remote_graph_execute_name);
@@ -152,18 +144,17 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
Node* node;
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
.Input(node_out_list)
.Attr("M", static_cast<int64>(output_list.size()))
.Attr("N", static_cast<int64>(outputs.size()))
.Attr("T", input_data_type)
.Attr("U", output_data_type)
.Attr("serialized_graph_transfer_info",
.Attr("Tinputs", input_types)
.Attr("Toutputs", output_types)
.Attr("serialized_remote_fused_graph_execute_info",
StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
CHECK(scope.ok());
CHECK(scope.ok()) << scope.status();
GraphDef fusedGraphDef;
TF_CHECK_OK(root.ToGraphDef(&fusedGraphDef));


@@ -29,8 +29,11 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
: OpKernel(ctx), execute_info_() {
string serialized_proto;
OP_REQUIRES_OK(
ctx, ctx->GetAttr("serialized_graph_transfer_info", &serialized_proto));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("serialized_remote_fused_graph_execute_info",
&serialized_proto));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_));
execute_info_.ParseFromString(serialized_proto);
if (!execute_info_.executor_name().empty()) {
const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func =
@@ -69,12 +72,15 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
void Compute(OpKernelContext* const ctx) final {
CHECK(ctx != nullptr);
const int input_count = ctx->num_inputs();
CHECK(input_count == execute_info_.graph_input_node_name_size())
const int graph_input_count = execute_info_.graph_input_node_name_size();
CHECK(input_count == graph_input_count &&
input_count == input_types_.size())
<< "input_count = " << input_count
<< ", gt input count = " << execute_info_.graph_input_node_name_size();
<< ", gt input count = " << execute_info_.graph_input_node_name_size()
<< ", type count = " << input_types_.size();
// 3. Send inputs into remote processor
for (int i = 0; i < input_count; ++i) {
// 3. Send first data type inputs into remote processor
for (int i = 0; i < graph_input_count; ++i) {
const Tensor& input_tensor = ctx->input(i);
const string& input_node_name = execute_info_.graph_input_node_name(i);
if (remote_fused_graph_executor_) {
@@ -90,7 +96,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
// 5. Load outputs from remote processor
const int output_count = ctx->num_outputs();
CHECK(output_count == execute_info_.graph_output_node_name_size());
CHECK(output_count == execute_info_.graph_output_node_name_size() &&
output_count == output_types_.size());
for (int i = 0; i < output_count; ++i) {
Tensor* output = nullptr;
const string& output_node_name = execute_info_.graph_output_node_name(i);
@@ -110,6 +117,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
private:
RemoteFusedGraphExecuteInfo execute_info_;
std::unique_ptr<IRemoteFusedGraphExecutor> remote_fused_graph_executor_;
DataTypeVector input_types_;
DataTypeVector output_types_;
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp);
};
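
The constructor above now reads the "Tinputs" and "Toutputs" attributes into input_types_ and output_types_, and Compute() only checks that the counts agree. If per-tensor dtype validation were layered on top, it could look roughly like the helper below (a sketch, not part of this commit; ValidateInputTypes is a hypothetical name):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Sketch only: confirm that each tensor handed to Compute() matches the
// corresponding entry of the "Tinputs" attribute captured at construction.
Status ValidateInputTypes(OpKernelContext* ctx,
                          const DataTypeVector& expected_types) {
  if (ctx->num_inputs() != static_cast<int>(expected_types.size())) {
    return errors::InvalidArgument("Expected ", expected_types.size(),
                                   " inputs but got ", ctx->num_inputs());
  }
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const DataType actual = ctx->input(i).dtype();
    if (actual != expected_types[i]) {
      return errors::InvalidArgument(
          "Input ", i, " has dtype ", DataTypeString(actual), " but Tinputs[",
          i, "] is ", DataTypeString(expected_types[i]));
    }
  }
  return Status::OK();
}

}  // namespace tensorflow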


@@ -37,20 +37,34 @@ namespace tensorflow {
class RemoteFusedGraphExecuteTest : public OpsTestBase {};
TEST_F(RemoteFusedGraphExecuteTest, ExecuteAddGraph) {
TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) {
DataTypeVector input_types({DT_FLOAT, DT_FLOAT});
DataTypeVector output_types({DT_FLOAT});
TF_ASSERT_OK(
NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
.Input(FakeInput(2, DT_FLOAT))
.Attr("M", 2)
.Attr("N", 1)
.Attr("T", DataTypeToEnum<float>::v())
.Attr("U", DataTypeToEnum<float>::v())
.Attr("serialized_graph_transfer_info", "")
.Attr("Tinputs", input_types)
.Attr("Toutputs", output_types)
.Attr("serialized_remote_fused_graph_execute_info", "")
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// TODO(satok): Add benchmark
}
TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) {
DataTypeVector input_types({DT_INT32, DT_INT32});
DataTypeVector output_types({DT_FLOAT});
ASSERT_FALSE(
NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
.Input(FakeInput(2, DT_FLOAT))
.Attr("Tinputs", input_types)
.Attr("Toutputs", output_types)
.Attr("serialized_remote_fused_graph_execute_info", "")
.Finalize(node_def())
.ok());
// TODO(satok): Add benchmark
}
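
The two tests above cover a single shared dtype and a dtype mismatch between the fake inputs and "Tinputs". A natural companion (hypothetical, not part of this commit) would exercise the mixed-type case the new attributes are meant to enable, in the same style:

TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithTwoDataTypes) {
  // Sketch only: two inputs with different dtypes, matching Tinputs.
  DataTypeVector input_types({DT_FLOAT, DT_INT32});
  DataTypeVector output_types({DT_FLOAT});
  TF_ASSERT_OK(
      NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
          .Input(FakeInput({DT_FLOAT, DT_INT32}))
          .Attr("Tinputs", input_types)
          .Attr("Toutputs", output_types)
          .Attr("serialized_remote_fused_graph_execute_info", "")
          .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
}
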
////////////////////////////
// End-to-end test: Begin //
////////////////////////////
@@ -94,13 +108,15 @@ static Output BuildRemoteFusedGraphExecuteOp(
CHECK(scope.ok());
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
DataTypeVector input_types{DT_FLOAT};
DataTypeVector output_types{DT_FLOAT};
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
.Input(node_out_list)
.Attr("M", static_cast<int64>(output_list.size()))
.Attr("N", static_cast<int64>(output_node_count))
.Attr("T", DT_FLOAT)
.Attr("U", DT_FLOAT)
.Attr("serialized_graph_transfer_info",
.Attr("Tinputs", input_types)
.Attr("Toutputs", output_types)
.Attr("serialized_remote_fused_graph_execute_info",
StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
scope.UpdateBuilder(&builder);


@@ -187,7 +187,8 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
const string& node_name) {
for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
if (node_name == pair.first) {
const TensorId tid = ParseTensorName(pair.first);
if (node_name == tid.first.ToString()) {
return true;
}
}
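
The updated comparison goes through ParseTensorName (from tensorflow/core/graph/tensor_id.h), so an entry such as "add:0" in the input tensor list still matches the bare node name "add". A small standalone sketch of the behavior relied on here (the node name "add" is only an example):

#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Sketch only: ParseTensorName splits "node:port" into a node name and an
// output index; a bare node name defaults to output index 0.
void ParseTensorNameExample() {
  const TensorId with_port = ParseTensorName("add:1");
  CHECK_EQ(with_port.first.ToString(), "add");
  CHECK_EQ(with_port.second, 1);
  const TensorId bare = ParseTensorName("add");
  CHECK_EQ(bare.first.ToString(), "add");
  CHECK_EQ(bare.second, 0);
}

}  // namespace tensorflow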


@@ -21,13 +21,11 @@ namespace tensorflow {
// TODO(satok): Implement shape_inference
REGISTER_OP("RemoteFusedGraphExecute")
.Input("values: M * T")
.Output("output: N * U")
.Attr("M: int >= 0")
.Attr("N: int >= 0")
.Attr("T: type")
.Attr("U: type")
.Attr("serialized_graph_transfer_info: string")
.Input("inputs: Tinputs")
.Output("outputs: Toutputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("Toutputs: list(type) >= 0")
.Attr("serialized_remote_fused_graph_execute_info: string")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Execute a sub graph on a remote processor transferred by GraphTransferer.


@@ -26,21 +26,33 @@ namespace tensorflow {
TEST(RemoteFusedGraphOpsTest, RemoteFusedGraphExecute_ShapeFn) {
ShapeInferenceTestOp op("RemoteFusedGraphExecute");
auto set_n = [&op](int input_count, int output_count) {
auto set_n = [&op](int input1_count, int input2_count, int output_count) {
std::vector<NodeDefBuilder::NodeOut> src_list;
for (int i = 0; i < input_count; ++i) {
DataTypeVector input_types;
for (int i = 0; i < input1_count; ++i) {
src_list.emplace_back("a", 0, DT_FLOAT);
input_types.emplace_back(DT_FLOAT);
}
TF_ASSERT_OK(NodeDefBuilder("test", "RemoteFusedGraphExecute")
.Input(src_list)
.Attr("M", input_count)
.Attr("N", output_count)
.Attr("T", DT_FLOAT)
.Attr("U", DT_FLOAT)
.Finalize(&op.node_def));
for (int i = 0; i < input2_count; ++i) {
src_list.emplace_back("b", 0, DT_INT32);
input_types.emplace_back(DT_INT32);
}
DataTypeVector output_types;
for (int i = 0; i < output_count; ++i) {
output_types.emplace_back(DT_FLOAT);
}
NodeDefBuilder builder = NodeDefBuilder("test", "RemoteFusedGraphExecute")
.Input(src_list)
.Attr("Tinputs", input_types)
.Attr("Toutputs", output_types);
TF_ASSERT_OK(builder.Finalize(&op.node_def));
};
set_n(4, 2);
set_n(4, 0, 2);
INFER_OK(op, "?;?;?;?", "?;?"); // output rank unknown
set_n(4, 3, 3);
INFER_OK(op, "?;?;?;?;?;?;?", "?;?;?"); // output rank unknown
// TODO(satok): Implement shape inference and do its test here
}