If _Retval is on TPU, use MTypeFromDTypeIntsOnDevice() instead of MTypeFromDType().
PiperOrigin-RevId: 246941558
This commit is contained in:
parent
b912370109
commit
cc91dc6aaa
@ -70,7 +70,7 @@ Status PartitionFunctionGraph(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status UpdateArgAndRetvalMetadata(
|
Status UpdateArgAndRetvalMetadata(
|
||||||
Graph* subgraph, std::vector<int>* arg_indices,
|
Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
|
||||||
std::vector<int>* ret_indices,
|
std::vector<int>* ret_indices,
|
||||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||||
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
std::vector<AllocatorAttributes>* ret_alloc_attrs) {
|
||||||
@ -101,7 +101,9 @@ Status UpdateArgAndRetvalMetadata(
|
|||||||
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
|
TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
|
||||||
AllocatorAttributes alloc_attr;
|
AllocatorAttributes alloc_attr;
|
||||||
DataType type = attr_value->type();
|
DataType type = attr_value->type();
|
||||||
if (MTypeFromDType(type) == HOST_MEMORY) {
|
MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
|
||||||
|
: MTypeFromDType(type);
|
||||||
|
if (mtype == HOST_MEMORY) {
|
||||||
alloc_attr.set_on_host(true);
|
alloc_attr.set_on_host(true);
|
||||||
}
|
}
|
||||||
arg_alloc_attrs->push_back(alloc_attr);
|
arg_alloc_attrs->push_back(alloc_attr);
|
||||||
@ -112,7 +114,9 @@ Status UpdateArgAndRetvalMetadata(
|
|||||||
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
|
TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
|
||||||
AllocatorAttributes alloc_attr;
|
AllocatorAttributes alloc_attr;
|
||||||
DataType type = attr_value->type();
|
DataType type = attr_value->type();
|
||||||
if (MTypeFromDType(type) == HOST_MEMORY) {
|
MemoryType mtype = (device_type == "TPU") ? MTypeFromDTypeIntsOnDevice(type)
|
||||||
|
: MTypeFromDType(type);
|
||||||
|
if (mtype == HOST_MEMORY) {
|
||||||
alloc_attr.set_on_host(true);
|
alloc_attr.set_on_host(true);
|
||||||
}
|
}
|
||||||
ret_alloc_attrs->push_back(alloc_attr);
|
ret_alloc_attrs->push_back(alloc_attr);
|
||||||
|
@ -57,7 +57,7 @@ Status PartitionFunctionGraph(
|
|||||||
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
// (3) records which `Arg` and `Retval` nodes live in host memory in
|
||||||
// `*_alloc_attrs`.
|
// `*_alloc_attrs`.
|
||||||
Status UpdateArgAndRetvalMetadata(
|
Status UpdateArgAndRetvalMetadata(
|
||||||
Graph* subgraph, std::vector<int>* arg_indices,
|
Graph* subgraph, const string& device_type, std::vector<int>* arg_indices,
|
||||||
std::vector<int>* ret_indices,
|
std::vector<int>* ret_indices,
|
||||||
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
std::vector<AllocatorAttributes>* arg_alloc_attrs,
|
||||||
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
std::vector<AllocatorAttributes>* ret_alloc_attrs);
|
||||||
|
@ -183,9 +183,11 @@ TEST_F(PartitioningUtilsTest, UpdateArgsAndRets) {
|
|||||||
std::vector<AllocatorAttributes> arg_alloc_attrs;
|
std::vector<AllocatorAttributes> arg_alloc_attrs;
|
||||||
std::vector<AllocatorAttributes> ret_alloc_attrs;
|
std::vector<AllocatorAttributes> ret_alloc_attrs;
|
||||||
|
|
||||||
Status status =
|
string device_type = "CPU";
|
||||||
UpdateArgAndRetvalMetadata(graph.get(), &arg_indices, &ret_indices,
|
|
||||||
&arg_alloc_attrs, &ret_alloc_attrs);
|
Status status = UpdateArgAndRetvalMetadata(
|
||||||
|
graph.get(), device_type, &arg_indices, &ret_indices, &arg_alloc_attrs,
|
||||||
|
&ret_alloc_attrs);
|
||||||
ASSERT_TRUE(status.ok()) << status.ToString();
|
ASSERT_TRUE(status.ok()) << status.ToString();
|
||||||
|
|
||||||
CheckIndices({3}, arg_indices);
|
CheckIndices({3}, arg_indices);
|
||||||
|
@ -739,17 +739,19 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
|||||||
// TODO(iga): Fail gracefully if the set of devices corresponds
|
// TODO(iga): Fail gracefully if the set of devices corresponds
|
||||||
// to more than one address space.
|
// to more than one address space.
|
||||||
const string& target = pair.first;
|
const string& target = pair.first;
|
||||||
|
FunctionLibraryRuntime* target_flr = GetFLR(target);
|
||||||
|
const string& device_type = target_flr->device()->device_type();
|
||||||
Graph* subgraph = pair.second.get();
|
Graph* subgraph = pair.second.get();
|
||||||
|
|
||||||
ComponentFunctionData* comp_data = &data->glue_[target];
|
ComponentFunctionData* comp_data = &data->glue_[target];
|
||||||
TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
|
TF_RETURN_IF_ERROR(UpdateArgAndRetvalMetadata(
|
||||||
subgraph, &comp_data->arg_indices_, &comp_data->ret_indices_,
|
subgraph, device_type, &comp_data->arg_indices_,
|
||||||
&comp_data->arg_alloc_attrs_, &comp_data->ret_alloc_attrs_));
|
&comp_data->ret_indices_, &comp_data->arg_alloc_attrs_,
|
||||||
|
&comp_data->ret_alloc_attrs_));
|
||||||
FunctionDef shard;
|
FunctionDef shard;
|
||||||
string unique_name = name_generator.GetName();
|
string unique_name = name_generator.GetName();
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
|
GraphToFunctionDef(*subgraph, unique_name, control_ret, &shard));
|
||||||
FunctionLibraryRuntime* target_flr = GetFLR(target);
|
|
||||||
TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
|
TF_RETURN_IF_ERROR(data->lib_def_.AddFunctionDef(shard));
|
||||||
FunctionLibraryRuntime::InstantiateOptions opts;
|
FunctionLibraryRuntime::InstantiateOptions opts;
|
||||||
opts.executor_type = options.executor_type;
|
opts.executor_type = options.executor_type;
|
||||||
|
Loading…
Reference in New Issue
Block a user