diff --git a/tensorflow/contrib/makefile/sub_makefiles/quantization/Makefile.in b/tensorflow/contrib/makefile/sub_makefiles/quantization/Makefile.in index bc7a238fdba..6ba41d5d12a 100644 --- a/tensorflow/contrib/makefile/sub_makefiles/quantization/Makefile.in +++ b/tensorflow/contrib/makefile/sub_makefiles/quantization/Makefile.in @@ -50,6 +50,7 @@ tensorflow/core/kernels/hexagon/graph_transferer_test.cc \ tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc \ tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc \ tensorflow/core/kernels/remote_fused_graph_execute_op.cc \ +tensorflow/core/kernels/remote_fused_graph_execute_utils.cc \ tensorflow/core/ops/remote_fused_graph_ops.cc \ tensorflow/core/platform/posix/test.cc diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc index 4352f13309f..aff78837a32 100644 --- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc +++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.cc @@ -107,18 +107,14 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def)); } CHECK(status.ok()); - const DataType input_data_type = - inputs.empty() ? DT_FLOAT : inputs.at(0).second.dtype(); Scope root = Scope::NewRootScope(); std::vector output_list; + DataTypeVector input_types; for (const std::pair& input_node_info : inputs) { const Scope& scope = root.WithOpName(input_node_info.first); Node* ret; const auto unique_name = scope.GetUniqueNameForOp("PlaceholderV2"); - const DataType dt = input_node_info.second.dtype(); - // DataType of input arguments should be same. 
- CHECK_EQ(input_data_type, dt); auto builder = NodeBuilder(unique_name, "PlaceholderV2") .Attr("dtype", input_node_info.second.dtype()) .Attr("shape", input_node_info.second.shape()); @@ -126,25 +122,21 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( scope.UpdateStatus(builder.Finalize(scope.graph(), &ret)); CHECK(scope.ok()); output_list.emplace_back(Output(ret, 0)); + input_types.push_back(input_node_info.second.dtype()); } const RemoteFusedGraphExecuteInfo execute_info = BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs, tensor_shape_map); - const std::pair* tensor_shape_type = - RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, - outputs.at(0)); - CHECK_NE(tensor_shape_type, nullptr); - const DataType output_data_type = tensor_shape_type->first; + DataTypeVector output_types; - // Sanity-check to confirm all output data types are same. + // Collect the data type of every output node. for (const string& output_node_name : outputs) { const std::pair* tst = RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map, output_node_name); CHECK_NE(tst, nullptr); - const DataType dt = tensor_shape_type->first; - CHECK_EQ(output_data_type, dt); + output_types.push_back(tst->first); } const Scope& scope = root.WithOpName(remote_graph_execute_name); @@ -152,18 +144,17 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo( auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list)); Node* node; const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute"); + auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute") .Input(node_out_list) - .Attr("M", static_cast(output_list.size())) - .Attr("N", static_cast(outputs.size())) - .Attr("T", input_data_type) - .Attr("U", output_data_type) - .Attr("serialized_graph_transfer_info", + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + .Attr("serialized_remote_fused_graph_execute_info", StringPiece(execute_info.SerializeAsString())); CHECK(scope.ok()); scope.UpdateBuilder(&builder); 
scope.UpdateStatus(builder.Finalize(scope.graph(), &node)); - CHECK(scope.ok()); + CHECK(scope.ok()) << scope.status(); GraphDef fusedGraphDef; TF_CHECK_OK(root.ToGraphDef(&fusedGraphDef)); diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc index bd95474a62b..101d0e694b2 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc @@ -29,8 +29,11 @@ class RemoteFusedGraphExecuteOp : public OpKernel { explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx) : OpKernel(ctx), execute_info_() { string serialized_proto; - OP_REQUIRES_OK( - ctx, ctx->GetAttr("serialized_graph_transfer_info", &serialized_proto)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("serialized_remote_fused_graph_execute_info", + &serialized_proto)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_)); execute_info_.ParseFromString(serialized_proto); if (!execute_info_.executor_name().empty()) { const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func = @@ -69,12 +72,15 @@ class RemoteFusedGraphExecuteOp : public OpKernel { void Compute(OpKernelContext* const ctx) final { CHECK(ctx != nullptr); const int input_count = ctx->num_inputs(); - CHECK(input_count == execute_info_.graph_input_node_name_size()) + const int graph_input_count = execute_info_.graph_input_node_name_size(); + CHECK(input_count == graph_input_count && + input_count == input_types_.size()) << "input_count = " << input_count - << ", gt input count = " << execute_info_.graph_input_node_name_size(); + << ", gt input count = " << execute_info_.graph_input_node_name_size() + << ", type count = " << input_types_.size(); - // 3. Send inputs into remote processor - for (int i = 0; i < input_count; ++i) { + // 3. 
Send first data type inputs into remote processor + for (int i = 0; i < graph_input_count; ++i) { const Tensor& input_tensor = ctx->input(i); const string& input_node_name = execute_info_.graph_input_node_name(i); if (remote_fused_graph_executor_) { @@ -90,7 +96,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel { // 5. Load outputs from remote processor const int output_count = ctx->num_outputs(); - CHECK(output_count == execute_info_.graph_output_node_name_size()); + CHECK(output_count == execute_info_.graph_output_node_name_size() && + output_count == output_types_.size()); for (int i = 0; i < output_count; ++i) { Tensor* output = nullptr; const string& output_node_name = execute_info_.graph_output_node_name(i); @@ -110,6 +117,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel { private: RemoteFusedGraphExecuteInfo execute_info_; std::unique_ptr remote_fused_graph_executor_; + DataTypeVector input_types_; + DataTypeVector output_types_; TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp); }; diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc index 580be4b7db9..925af1f79e2 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op_test.cc @@ -37,20 +37,34 @@ namespace tensorflow { class RemoteFusedGraphExecuteTest : public OpsTestBase {}; -TEST_F(RemoteFusedGraphExecuteTest, ExecuteAddGraph) { +TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) { + DataTypeVector input_types({DT_FLOAT, DT_FLOAT}); + DataTypeVector output_types({DT_FLOAT}); TF_ASSERT_OK( NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute") .Input(FakeInput(2, DT_FLOAT)) - .Attr("M", 2) - .Attr("N", 1) - .Attr("T", DataTypeToEnum::v()) - .Attr("U", DataTypeToEnum::v()) - .Attr("serialized_graph_transfer_info", "") + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + 
.Attr("serialized_remote_fused_graph_execute_info", "") .Finalize(node_def())); TF_ASSERT_OK(InitOp()); // TODO(satok): Add benchmark } +TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) { + DataTypeVector input_types({DT_INT32, DT_INT32}); + DataTypeVector output_types({DT_FLOAT}); + ASSERT_FALSE( + NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute") + .Input(FakeInput(2, DT_FLOAT)) + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + .Attr("serialized_remote_fused_graph_execute_info", "") + .Finalize(node_def()) + .ok()); + // TODO(satok): Add benchmark +} + //////////////////////////// // End-to-end test: Begin // //////////////////////////// @@ -94,13 +108,15 @@ static Output BuildRemoteFusedGraphExecuteOp( CHECK(scope.ok()); auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list)); const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute"); + + DataTypeVector input_types{DT_FLOAT}; + DataTypeVector output_types{DT_FLOAT}; + auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute") .Input(node_out_list) - .Attr("M", static_cast(output_list.size())) - .Attr("N", static_cast(output_node_count)) - .Attr("T", DT_FLOAT) - .Attr("U", DT_FLOAT) - .Attr("serialized_graph_transfer_info", + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types) + .Attr("serialized_remote_fused_graph_execute_info", StringPiece(execute_info.SerializeAsString())); CHECK(scope.ok()); scope.UpdateBuilder(&builder); diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index a6eb06004a8..b885c5ce095 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -187,7 +187,8 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { const std::vector>& input_tensor_vector, const string& node_name) { for (const std::pair& pair : 
input_tensor_vector) { - if (node_name == pair.first) { + const TensorId tid = ParseTensorName(pair.first); + if (node_name == tid.first.ToString()) { return true; } } diff --git a/tensorflow/core/ops/remote_fused_graph_ops.cc b/tensorflow/core/ops/remote_fused_graph_ops.cc index 3d90c054d47..6e9f37a6152 100644 --- a/tensorflow/core/ops/remote_fused_graph_ops.cc +++ b/tensorflow/core/ops/remote_fused_graph_ops.cc @@ -21,13 +21,11 @@ namespace tensorflow { // TODO(satok): Implement shape_inference REGISTER_OP("RemoteFusedGraphExecute") - .Input("values: M * T") - .Output("output: N * U") - .Attr("M: int >= 0") - .Attr("N: int >= 0") - .Attr("T: type") - .Attr("U: type") - .Attr("serialized_graph_transfer_info: string") + .Input("inputs: Tinputs") + .Output("outputs: Toutputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("Toutputs: list(type) >= 0") + .Attr("serialized_remote_fused_graph_execute_info: string") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Execute a sub graph on a remote processor transferred by GraphTransferer. 
diff --git a/tensorflow/core/ops/remote_fused_graph_ops_test.cc b/tensorflow/core/ops/remote_fused_graph_ops_test.cc index 7fbe213e20f..f5d90a676d7 100644 --- a/tensorflow/core/ops/remote_fused_graph_ops_test.cc +++ b/tensorflow/core/ops/remote_fused_graph_ops_test.cc @@ -26,21 +26,33 @@ namespace tensorflow { TEST(RemoteFusedGraphOpsTest, RemoteFusedGraphExecute_ShapeFn) { ShapeInferenceTestOp op("RemoteFusedGraphExecute"); - auto set_n = [&op](int input_count, int output_count) { + auto set_n = [&op](int input1_count, int input2_count, int output_count) { std::vector src_list; - for (int i = 0; i < input_count; ++i) { + DataTypeVector input_types; + for (int i = 0; i < input1_count; ++i) { src_list.emplace_back("a", 0, DT_FLOAT); + input_types.emplace_back(DT_FLOAT); } - TF_ASSERT_OK(NodeDefBuilder("test", "RemoteFusedGraphExecute") - .Input(src_list) - .Attr("M", input_count) - .Attr("N", output_count) - .Attr("T", DT_FLOAT) - .Attr("U", DT_FLOAT) - .Finalize(&op.node_def)); + for (int i = 0; i < input2_count; ++i) { + src_list.emplace_back("b", 0, DT_INT32); + input_types.emplace_back(DT_INT32); + } + DataTypeVector output_types; + for (int i = 0; i < output_count; ++i) { + output_types.emplace_back(DT_FLOAT); + } + NodeDefBuilder builder = NodeDefBuilder("test", "RemoteFusedGraphExecute") + .Input(src_list) + .Attr("Tinputs", input_types) + .Attr("Toutputs", output_types); + TF_ASSERT_OK(builder.Finalize(&op.node_def)); }; - set_n(4, 2); + set_n(4, 0, 2); INFER_OK(op, "?;?;?;?", "?;?"); // output rank unknown + + set_n(4, 3, 3); + INFER_OK(op, "?;?;?;?;?;?;?", "?;?;?"); // output rank unknown + // TODO(satok): Implement shape inference and do its test here }