Support two data types as inputs in RemoteFusedGraphExecuteOp
Change: 152843285
This commit is contained in:
parent
740053be50
commit
a11e669c5a
@ -50,6 +50,7 @@ tensorflow/core/kernels/hexagon/graph_transferer_test.cc \
|
||||
tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc \
|
||||
tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc \
|
||||
tensorflow/core/kernels/remote_fused_graph_execute_op.cc \
|
||||
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc \
|
||||
tensorflow/core/ops/remote_fused_graph_ops.cc \
|
||||
tensorflow/core/platform/posix/test.cc
|
||||
|
||||
|
@ -107,18 +107,14 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
|
||||
AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
|
||||
}
|
||||
CHECK(status.ok());
|
||||
const DataType input_data_type =
|
||||
inputs.empty() ? DT_FLOAT : inputs.at(0).second.dtype();
|
||||
|
||||
Scope root = Scope::NewRootScope();
|
||||
std::vector<Output> output_list;
|
||||
DataTypeVector input_types;
|
||||
for (const std::pair<string, Tensor>& input_node_info : inputs) {
|
||||
const Scope& scope = root.WithOpName(input_node_info.first);
|
||||
Node* ret;
|
||||
const auto unique_name = scope.GetUniqueNameForOp("PlaceholderV2");
|
||||
const DataType dt = input_node_info.second.dtype();
|
||||
// DataType of input arguments should be same.
|
||||
CHECK_EQ(input_data_type, dt);
|
||||
auto builder = NodeBuilder(unique_name, "PlaceholderV2")
|
||||
.Attr("dtype", input_node_info.second.dtype())
|
||||
.Attr("shape", input_node_info.second.shape());
|
||||
@ -126,25 +122,21 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
|
||||
CHECK(scope.ok());
|
||||
output_list.emplace_back(Output(ret, 0));
|
||||
input_types.push_back(input_node_info.second.dtype());
|
||||
}
|
||||
|
||||
const RemoteFusedGraphExecuteInfo execute_info =
|
||||
BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs,
|
||||
tensor_shape_map);
|
||||
|
||||
const std::pair<DataType, TensorShape>* tensor_shape_type =
|
||||
RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
|
||||
outputs.at(0));
|
||||
CHECK_NE(tensor_shape_type, nullptr);
|
||||
const DataType output_data_type = tensor_shape_type->first;
|
||||
DataTypeVector output_types;
|
||||
// Sanity-check to confirm all output data types are same.
|
||||
for (const string& output_node_name : outputs) {
|
||||
const std::pair<DataType, TensorShape>* tst =
|
||||
RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
|
||||
output_node_name);
|
||||
CHECK_NE(tst, nullptr);
|
||||
const DataType dt = tensor_shape_type->first;
|
||||
CHECK_EQ(output_data_type, dt);
|
||||
output_types.push_back(tst->first);
|
||||
}
|
||||
|
||||
const Scope& scope = root.WithOpName(remote_graph_execute_name);
|
||||
@ -152,18 +144,17 @@ GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
|
||||
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
|
||||
Node* node;
|
||||
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
|
||||
|
||||
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
|
||||
.Input(node_out_list)
|
||||
.Attr("M", static_cast<int64>(output_list.size()))
|
||||
.Attr("N", static_cast<int64>(outputs.size()))
|
||||
.Attr("T", input_data_type)
|
||||
.Attr("U", output_data_type)
|
||||
.Attr("serialized_graph_transfer_info",
|
||||
.Attr("Tinputs", input_types)
|
||||
.Attr("Toutputs", output_types)
|
||||
.Attr("serialized_remote_fused_graph_execute_info",
|
||||
StringPiece(execute_info.SerializeAsString()));
|
||||
CHECK(scope.ok());
|
||||
scope.UpdateBuilder(&builder);
|
||||
scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
|
||||
CHECK(scope.ok());
|
||||
CHECK(scope.ok()) << scope.status();
|
||||
|
||||
GraphDef fusedGraphDef;
|
||||
TF_CHECK_OK(root.ToGraphDef(&fusedGraphDef));
|
||||
|
@ -29,8 +29,11 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
|
||||
explicit RemoteFusedGraphExecuteOp(OpKernelConstruction* const ctx)
|
||||
: OpKernel(ctx), execute_info_() {
|
||||
string serialized_proto;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr("serialized_graph_transfer_info", &serialized_proto));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetAttr("serialized_remote_fused_graph_execute_info",
|
||||
&serialized_proto));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tinputs", &input_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("Toutputs", &output_types_));
|
||||
execute_info_.ParseFromString(serialized_proto);
|
||||
if (!execute_info_.executor_name().empty()) {
|
||||
const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* build_func =
|
||||
@ -69,12 +72,15 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
|
||||
void Compute(OpKernelContext* const ctx) final {
|
||||
CHECK(ctx != nullptr);
|
||||
const int input_count = ctx->num_inputs();
|
||||
CHECK(input_count == execute_info_.graph_input_node_name_size())
|
||||
const int graph_input_count = execute_info_.graph_input_node_name_size();
|
||||
CHECK(input_count == graph_input_count &&
|
||||
input_count == input_types_.size())
|
||||
<< "input_count = " << input_count
|
||||
<< ", gt input count = " << execute_info_.graph_input_node_name_size();
|
||||
<< ", gt input count = " << execute_info_.graph_input_node_name_size()
|
||||
<< ", type count = " << input_types_.size();
|
||||
|
||||
// 3. Send inputs into remote processor
|
||||
for (int i = 0; i < input_count; ++i) {
|
||||
// 3. Send first data type inputs into remote processor
|
||||
for (int i = 0; i < graph_input_count; ++i) {
|
||||
const Tensor& input_tensor = ctx->input(i);
|
||||
const string& input_node_name = execute_info_.graph_input_node_name(i);
|
||||
if (remote_fused_graph_executor_) {
|
||||
@ -90,7 +96,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
|
||||
|
||||
// 5. Load outputs from remote processor
|
||||
const int output_count = ctx->num_outputs();
|
||||
CHECK(output_count == execute_info_.graph_output_node_name_size());
|
||||
CHECK(output_count == execute_info_.graph_output_node_name_size() &&
|
||||
output_count == output_types_.size());
|
||||
for (int i = 0; i < output_count; ++i) {
|
||||
Tensor* output = nullptr;
|
||||
const string& output_node_name = execute_info_.graph_output_node_name(i);
|
||||
@ -110,6 +117,8 @@ class RemoteFusedGraphExecuteOp : public OpKernel {
|
||||
private:
|
||||
RemoteFusedGraphExecuteInfo execute_info_;
|
||||
std::unique_ptr<IRemoteFusedGraphExecutor> remote_fused_graph_executor_;
|
||||
DataTypeVector input_types_;
|
||||
DataTypeVector output_types_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteOp);
|
||||
};
|
||||
|
@ -37,20 +37,34 @@ namespace tensorflow {
|
||||
|
||||
class RemoteFusedGraphExecuteTest : public OpsTestBase {};
|
||||
|
||||
TEST_F(RemoteFusedGraphExecuteTest, ExecuteAddGraph) {
|
||||
TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithOneDataType) {
|
||||
DataTypeVector input_types({DT_FLOAT, DT_FLOAT});
|
||||
DataTypeVector output_types({DT_FLOAT});
|
||||
TF_ASSERT_OK(
|
||||
NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
|
||||
.Input(FakeInput(2, DT_FLOAT))
|
||||
.Attr("M", 2)
|
||||
.Attr("N", 1)
|
||||
.Attr("T", DataTypeToEnum<float>::v())
|
||||
.Attr("U", DataTypeToEnum<float>::v())
|
||||
.Attr("serialized_graph_transfer_info", "")
|
||||
.Attr("Tinputs", input_types)
|
||||
.Attr("Toutputs", output_types)
|
||||
.Attr("serialized_remote_fused_graph_execute_info", "")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
// TODO(satok): Add benchmark
|
||||
}
|
||||
|
||||
TEST_F(RemoteFusedGraphExecuteTest, BuildModelWithWrongDataType) {
|
||||
DataTypeVector input_types({DT_INT32, DT_INT32});
|
||||
DataTypeVector output_types({DT_FLOAT});
|
||||
ASSERT_FALSE(
|
||||
NodeDefBuilder("remote_fused_graph_execute_op", "RemoteFusedGraphExecute")
|
||||
.Input(FakeInput(2, DT_FLOAT))
|
||||
.Attr("Tinputs", input_types)
|
||||
.Attr("Toutputs", output_types)
|
||||
.Attr("serialized_remote_fused_graph_execute_info", "")
|
||||
.Finalize(node_def())
|
||||
.ok());
|
||||
// TODO(satok): Add benchmark
|
||||
}
|
||||
|
||||
////////////////////////////
|
||||
// End-to-end test: Begin //
|
||||
////////////////////////////
|
||||
@ -94,13 +108,15 @@ static Output BuildRemoteFusedGraphExecuteOp(
|
||||
CHECK(scope.ok());
|
||||
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
|
||||
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
|
||||
|
||||
DataTypeVector input_types{DT_FLOAT};
|
||||
DataTypeVector output_types{DT_FLOAT};
|
||||
|
||||
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
|
||||
.Input(node_out_list)
|
||||
.Attr("M", static_cast<int64>(output_list.size()))
|
||||
.Attr("N", static_cast<int64>(output_node_count))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("U", DT_FLOAT)
|
||||
.Attr("serialized_graph_transfer_info",
|
||||
.Attr("Tinputs", input_types)
|
||||
.Attr("Toutputs", output_types)
|
||||
.Attr("serialized_remote_fused_graph_execute_info",
|
||||
StringPiece(execute_info.SerializeAsString()));
|
||||
CHECK(scope.ok());
|
||||
scope.UpdateBuilder(&builder);
|
||||
|
@ -187,7 +187,8 @@ RemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
|
||||
const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
|
||||
const string& node_name) {
|
||||
for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
|
||||
if (node_name == pair.first) {
|
||||
const TensorId tid = ParseTensorName(pair.first);
|
||||
if (node_name == tid.first.ToString()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -21,13 +21,11 @@ namespace tensorflow {
|
||||
|
||||
// TODO(satok): Implement shape_inference
|
||||
REGISTER_OP("RemoteFusedGraphExecute")
|
||||
.Input("values: M * T")
|
||||
.Output("output: N * U")
|
||||
.Attr("M: int >= 0")
|
||||
.Attr("N: int >= 0")
|
||||
.Attr("T: type")
|
||||
.Attr("U: type")
|
||||
.Attr("serialized_graph_transfer_info: string")
|
||||
.Input("inputs: Tinputs")
|
||||
.Output("outputs: Toutputs")
|
||||
.Attr("Tinputs: list(type) >= 0")
|
||||
.Attr("Toutputs: list(type) >= 0")
|
||||
.Attr("serialized_remote_fused_graph_execute_info: string")
|
||||
.SetShapeFn(shape_inference::UnknownShape)
|
||||
.Doc(R"doc(
|
||||
Execute a sub graph on a remote processor transferred by GraphTransferer.
|
||||
|
@ -26,21 +26,33 @@ namespace tensorflow {
|
||||
|
||||
TEST(RemoteFusedGraphOpsTest, RemoteFusedGraphExecute_ShapeFn) {
|
||||
ShapeInferenceTestOp op("RemoteFusedGraphExecute");
|
||||
auto set_n = [&op](int input_count, int output_count) {
|
||||
auto set_n = [&op](int input1_count, int input2_count, int output_count) {
|
||||
std::vector<NodeDefBuilder::NodeOut> src_list;
|
||||
for (int i = 0; i < input_count; ++i) {
|
||||
DataTypeVector input_types;
|
||||
for (int i = 0; i < input1_count; ++i) {
|
||||
src_list.emplace_back("a", 0, DT_FLOAT);
|
||||
input_types.emplace_back(DT_FLOAT);
|
||||
}
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "RemoteFusedGraphExecute")
|
||||
.Input(src_list)
|
||||
.Attr("M", input_count)
|
||||
.Attr("N", output_count)
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("U", DT_FLOAT)
|
||||
.Finalize(&op.node_def));
|
||||
for (int i = 0; i < input2_count; ++i) {
|
||||
src_list.emplace_back("b", 0, DT_INT32);
|
||||
input_types.emplace_back(DT_INT32);
|
||||
}
|
||||
DataTypeVector output_types;
|
||||
for (int i = 0; i < output_count; ++i) {
|
||||
output_types.emplace_back(DT_FLOAT);
|
||||
}
|
||||
NodeDefBuilder builder = NodeDefBuilder("test", "RemoteFusedGraphExecute")
|
||||
.Input(src_list)
|
||||
.Attr("Tinputs", input_types)
|
||||
.Attr("Toutputs", output_types);
|
||||
TF_ASSERT_OK(builder.Finalize(&op.node_def));
|
||||
};
|
||||
set_n(4, 2);
|
||||
set_n(4, 0, 2);
|
||||
INFER_OK(op, "?;?;?;?", "?;?"); // output rank unknown
|
||||
|
||||
set_n(4, 3, 3);
|
||||
INFER_OK(op, "?;?;?;?;?;?;?", "?;?;?"); // output rank unknown
|
||||
|
||||
// TODO(satok): Implement shape inference and do its test here
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user