Unify logic for deducing memory types for cases with and without kernel definition
This is not an NFC (no-functional-change) commit: the new version also starts respecting the `_input_hostmem` and `_output_hostmem` attributes in the function call case, which is necessary for subsequent changes.
PiperOrigin-RevId: 303873485
Change-Id: Ib366976c2121f0b5066236e52d1776d7a4535b4e
This commit is contained in:
parent
877d642a1a
commit
292fd8cc9c
@ -104,60 +104,51 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
||||
return it != ndef.attr().end() && it->second.b();
|
||||
}();
|
||||
|
||||
// For functions (which have no KernelDef) and their gradients, we can only
|
||||
// best-effort derive the memory type from the data type. For now, we assume
|
||||
// int32 is always on host memory and other types are always on device memory.
|
||||
// TODO(zhifengc,phawkins): We should do type inference over function bodies
|
||||
// to derive the correct input/output memory types. We should also split
|
||||
// host-memory and non host-memory arguments into separate type lists.
|
||||
if (!status.ok() || IsFunctionCallOp(ndef.op())) {
|
||||
if (device_type.type_string() == "TPU" || has_xla_compile) {
|
||||
// Here we assume that if tf.function() is called within
|
||||
// "with tf.device('/device:TPU:0')", the whole function will be compiled
|
||||
// and executed on TPU. This is true today, but when we implement auto
|
||||
// clustering on function body, this will no longer be true. For example,
|
||||
// we might want to place string arguments on host.
|
||||
for (const auto& t : inp_dtypes)
|
||||
inp_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t));
|
||||
for (const auto& t : out_dtypes)
|
||||
out_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t));
|
||||
} else {
|
||||
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
|
||||
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
|
||||
bool has_kernel_def = status.ok() && !IsFunctionCallOp(ndef.op());
|
||||
auto host_memory_required = [&](const DataType& dt) {
|
||||
bool int32_on_device =
|
||||
has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile;
|
||||
return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device);
|
||||
};
|
||||
|
||||
if (has_kernel_def) {
|
||||
// Gets the input/output names and their corresponding endpoint ranges.
|
||||
NameRangeMap inp_names;
|
||||
NameRangeMap out_names;
|
||||
TF_RETURN_IF_ERROR(
|
||||
NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
|
||||
|
||||
// Now that we know the size, fill with the default 'DEVICE_MEMORY'.
|
||||
inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
|
||||
out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
|
||||
|
||||
// Fills in host memory types based on the kernel def.
|
||||
const auto& from_proto = kdef->host_memory_arg();
|
||||
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
|
||||
MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
|
||||
MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
|
||||
if (!host_memory_args.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
|
||||
"' not found in OpDef: ", SummarizeOpDef(*op_def));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Gets the input/output names and their corresponding endpoint ranges.
|
||||
NameRangeMap inp_names;
|
||||
NameRangeMap out_names;
|
||||
TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
|
||||
|
||||
// Now that we know the size, fill with the default 'DEVICE_MEMORY'.
|
||||
inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
|
||||
out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
|
||||
|
||||
// Fills in host memory types based on the kernel def.
|
||||
const auto& from_proto = kdef->host_memory_arg();
|
||||
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
|
||||
MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
|
||||
MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
|
||||
if (!host_memory_args.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
"HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
|
||||
"' not found in OpDef: ", SummarizeOpDef(*op_def));
|
||||
} else {
|
||||
// Set all the datatype to DEVICE_MEMORY by default, later on change it to
|
||||
// HOST_MEMORY where it is required by the datatype.
|
||||
inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY);
|
||||
out_mtypes->resize(out_dtypes.size(), DEVICE_MEMORY);
|
||||
}
|
||||
CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
|
||||
CHECK_LE(out_mtypes->size(), out_dtypes.size());
|
||||
|
||||
// Mark e.g. all resource and string types as host memory.
|
||||
for (int i = 0; i < inp_mtypes->size(); ++i) {
|
||||
if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
|
||||
if (host_memory_required(inp_dtypes[i])) {
|
||||
(*inp_mtypes)[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < out_mtypes->size(); ++i) {
|
||||
if (DataTypeAlwaysOnHost(out_dtypes[i])) {
|
||||
if (host_memory_required(out_dtypes[i])) {
|
||||
(*out_mtypes)[i] = HOST_MEMORY;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user