From cc91dc6aaa3457254988dd865564a3a3394a9800 Mon Sep 17 00:00:00 2001
From: Tong Shen
Date: Mon, 6 May 2019 19:06:07 -0700
Subject: [PATCH] If _Retval is on TPU, use MTypeFromDTypeIntsOnDevice()
 instead of MTypeFromDType().

PiperOrigin-RevId: 246941558
---
 tensorflow/core/common_runtime/partitioning_utils.cc       | 10 +++++++---
 tensorflow/core/common_runtime/partitioning_utils.h        |  2 +-
 .../core/common_runtime/partitioning_utils_test.cc         |  8 +++++---
 .../common_runtime/process_function_library_runtime.cc     |  8 +++++---
 4 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/tensorflow/core/common_runtime/partitioning_utils.cc b/tensorflow/core/common_runtime/partitioning_utils.cc
index d700040f8af..0fcf70d0491 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.cc
+++ b/tensorflow/core/common_runtime/partitioning_utils.cc
@@ -70,7 +70,7 @@ Status PartitionFunctionGraph(
 }
 
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, std::vector<int>* arg_indices,
+    Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
     std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -101,7 +101,9 @@ Status UpdateArgAndRetvalMetadata(
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
     AllocatorAttributes alloc_attr;
     DataType type = attr_value->type();
-    if (MTypeFromDType(type) == HOST_MEMORY) {
+    MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
+                                              : MTypeFromDType(type);
+    if (mtype == HOST_MEMORY) {
       alloc_attr.set_on_host(true);
     }
     arg_alloc_attrs->push_back(alloc_attr);
@@ -112,7 +114,9 @@ Status UpdateArgAndRetvalMetadata(
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
     AllocatorAttributes alloc_attr;
     DataType type = attr_value->type();
-    if (MTypeFromDType(type) == HOST_MEMORY) {
+    MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
+                                              : MTypeFromDType(type);
+    if (mtype == HOST_MEMORY) {
       alloc_attr.set_on_host(true);
     }
     ret_alloc_attrs->push_back(alloc_attr);
diff --git a/tensorflow/core/common_runtime/partitioning_utils.h b/tensorflow/core/common_runtime/partitioning_utils.h
index c282647e702..1a6551ac858 100644
--- a/tensorflow/core/common_runtime/partitioning_utils.h
+++ b/tensorflow/core/common_runtime/partitioning_utils.h
@@ -57,7 +57,7 @@ Status PartitionFunctionGraph(
 // (3) records which `Arg` and `Retval` nodes live in host memory in
 // `*_alloc_attrs`.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, std::vector<int>* arg_indices,
+    Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
     std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs);
diff --git a/tensorflow/core/common_runtime/partitioning_utils_test.cc b/tensorflow/core/common_runtime/partitioning_utils_test.cc
index 2900c0a934f..9a3f3a68e34 100644
--- a/tensorflow/core/common_runtime/partitioning_utils_test.cc
+++ b/tensorflow/core/common_runtime/partitioning_utils_test.cc
@@ -183,9 +183,11 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
   std::vector<AllocatorAttributes> arg_alloc_attrs;
   std::vector<AllocatorAttributes> ret_alloc_attrs;
 
-  Status status =
-      UpdateArgAndRetvalMetadata(graph.get(), &arg_indices, &ret_indices,
-                                 &arg_alloc_attrs, &ret_alloc_attrs);
+  string device_type = "CPU";
+
+  Status status = UpdateArgAndRetvalMetadata(
+      graph.get(), device_type, &arg_indices, &ret_indices, &arg_alloc_attrs,
+      &ret_alloc_attrs);
 
   ASSERT_TRUE(status.ok()) << status.ToString();
   CheckIndices({3}, arg_indices);
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index b3b75c6fe1b..a24757f33fa 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -739,17 +739,19 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     // TODO(iga): Fail gracefully if the set of devices corresponds
     // to more than one address space.
     const string& target = pair.first;
+    FunctionLibraryRuntime* target_flr = GetFLR(target);
+    const string& device_type = target_flr->device()->device_type();
     Graph* subgraph = pair.second.get();
     ComponentFunctionData* comp_data = &data->glue_[target];
     TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
-        subgraph, &comp_data->arg_indices_, &comp_data->ret_indices_,
-        &comp_data->arg_alloc_attrs_, &comp_data->ret_alloc_attrs_));
+        subgraph, device_type, &comp_data->arg_indices_,
+        &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
+        &comp_data->ret_alloc_attrs_));
     FunctionDef shard;
     string unique_name = name_generator.GetName();
     TF_RETURN_IF_ERROR(
         GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
-    FunctionLibraryRuntime* target_flr = GetFLR(target);
 
     TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
     FunctionLibraryRuntime::InstantiateOptions opts;
     opts.executor_type = options.executor_type;