If an _Arg or _Retval node is placed on TPU, use MTypeFromDTypeIntsOnDevice() instead of MTypeFromDType() to decide whether it lives in host memory.

PiperOrigin-RevId: 246941558
commit cc91dc6aaa
parent b912370109
Author: Tong Shen, 2019-05-06 19:06:07 -07:00 (committed by TensorFlower Gardener)
4 changed files with 18 additions and 10 deletions
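
For context on the commit message (an explanatory note, not part of the diff):
MTypeFromDType() reports HOST_MEMORY for DT_INT32 and the other always-on-host
types, because most devices keep int32 tensors in host memory; TPUs keep int32
values in device memory, which MTypeFromDTypeIntsOnDevice() reflects. Below is
a minimal sketch of the selection rule the hunks introduce, assuming the
declarations in tensorflow/core/framework/types.h (SelectMemoryType is a
hypothetical helper name, not part of this change):

  #include "tensorflow/core/framework/types.h"

  namespace tensorflow {

  // Sketch: on TPU, route through MTypeFromDTypeIntsOnDevice() so DT_INT32
  // resolves to DEVICE_MEMORY; elsewhere, MTypeFromDType() keeps int32 in
  // HOST_MEMORY.
  MemoryType SelectMemoryType(const string& device_type, DataType type) {
    return (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
                                  : MTypeFromDType(type);
  }

  }  // namespace tensorflow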


@@ -70,7 +70,7 @@ Status PartitionFunctionGraph(
 }
 
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, std::vector<int>* arg_indices,
+    Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
     std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
     std::vector<AllocatorAttributes>* ret_alloc_attrs) {
@@ -101,7 +101,9 @@ Status UpdateArgAndRetvalMetadata(
     TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
     AllocatorAttributes alloc_attr;
     DataType type = attr_value->type();
-    if (MTypeFromDType(type) == HOST_MEMORY) {
+    MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
+                                              : MTypeFromDType(type);
+    if (mtype == HOST_MEMORY) {
       alloc_attr.set_on_host(true);
     }
     arg_alloc_attrs->push_back(alloc_attr);
@@ -112,7 +114,9 @@ Status UpdateArgAndRetvalMetadata(
     TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
     AllocatorAttributes alloc_attr;
     DataType type = attr_value->type();
-    if (MTypeFromDType(type) == HOST_MEMORY) {
+    MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
+                                              : MTypeFromDType(type);
+    if (mtype == HOST_MEMORY) {
       alloc_attr.set_on_host(true);
     }
     ret_alloc_attrs->push_back(alloc_attr);


@@ -57,7 +57,7 @@ Status PartitionFunctionGraph(
 // (3) records which `Arg` and `Retval` nodes live in host memory in
 //     `*_alloc_attrs`.
 Status UpdateArgAndRetvalMetadata(
-    Graph* subgraph, std::vector<int>* arg_indices,
+    Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
     std::vector<int>* ret_indices,
     std::vector<AllocatorAttributes>* arg_alloc_attrs,
    std::vector<AllocatorAttributes>* ret_alloc_attrs);
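
The comment above notes that the function records which `Arg`/`Retval` nodes
live in host memory via `*_alloc_attrs`. As a hedged illustration of how a
caller could consume that output (the loop is invented for this note;
AllocatorAttributes::on_host() is the real accessor):

  // Sketch: inspect the attributes filled in by UpdateArgAndRetvalMetadata.
  for (size_t i = 0; i < arg_alloc_attrs.size(); ++i) {
    if (arg_alloc_attrs[i].on_host()) {
      // The i-th _Arg tensor must be allocated in host memory (e.g. an int32
      // value on CPU/GPU); with this change, int32 on TPU stays on device.
    }
  }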


@@ -183,9 +183,11 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
   std::vector<AllocatorAttributes> arg_alloc_attrs;
   std::vector<AllocatorAttributes> ret_alloc_attrs;
-  Status status =
-      UpdateArgAndRetvalMetadata(graph.get(), &arg_indices, &ret_indices,
-                                 &arg_alloc_attrs, &ret_alloc_attrs);
+  string device_type = "CPU";
+  Status status = UpdateArgAndRetvalMetadata(
+      graph.get(), device_type, &arg_indices, &ret_indices, &arg_alloc_attrs,
+      &ret_alloc_attrs);
   ASSERT_TRUE(status.ok()) << status.ToString();
   CheckIndices({3}, arg_indices);


@@ -739,17 +739,19 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
     // TODO(iga): Fail gracefully if the set of devices corresponds
     // to more than one address space.
     const string& target = pair.first;
+    FunctionLibraryRuntime* target_flr = GetFLR(target);
+    const string& device_type = target_flr->device()->device_type();
     Graph* subgraph = pair.second.get();
     ComponentFunctionData* comp_data = &data->glue_[target];
     TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
-        subgraph, &comp_data->arg_indices_, &comp_data->ret_indices_,
-        &comp_data->arg_alloc_attrs_, &comp_data->ret_alloc_attrs_));
+        subgraph, device_type, &comp_data->arg_indices_,
+        &comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
+        &comp_data->ret_alloc_attrs_));
     FunctionDef shard;
     string unique_name = name_generator.GetName();
     TF_RETURN_IF_ERROR(
         GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
-    FunctionLibraryRuntime* target_flr = GetFLR(target);
     TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
     FunctionLibraryRuntime::InstantiateOptions opts;
     opts.executor_type = options.executor_type;