Unify logic for deducing memory types for cases with and without kernel definition

This is not an NFC change, since the new version also starts to respect
`_input_hostmem` and `_output_hostmem` attributes for the function call case,
which would be necessary for subsequent changes.

PiperOrigin-RevId: 303873485
Change-Id: Ib366976c2121f0b5066236e52d1776d7a4535b4e
This commit is contained in:
George Karpenkov 2020-03-30 19:25:28 -07:00 committed by TensorFlower Gardener
parent 877d642a1a
commit 292fd8cc9c

View File

@ -104,60 +104,51 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
return it != ndef.attr().end() && it->second.b();
}();
// For functions (which have no KernelDef) and their gradients, we can only
// best-effort derive the memory type from the data type. For now, we assume
// int32 is always on host memory and other types are always on device memory.
// TODO(zhifengc,phawkins): We should do type inference over function bodies
// to derive the correct input/output memory types. We should also split
// host-memory and non host-memory arguments into separate type lists.
if (!status.ok() || IsFunctionCallOp(ndef.op())) {
if (device_type.type_string() == "TPU" || has_xla_compile) {
// Here we assume that if tf.function() is called within
// "with tf.device('/device:TPU:0')", the whole function will be compiled
// and executed on TPU. This is true today, but when we implement auto
// clustering on function body, this will no longer be true. For example,
// we might want to place string arguments on host.
for (const auto& t : inp_dtypes)
inp_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t));
for (const auto& t : out_dtypes)
out_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t));
} else {
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
bool has_kernel_def = status.ok() && !IsFunctionCallOp(ndef.op());
auto host_memory_required = [&](const DataType& dt) {
bool int32_on_device =
has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile;
return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device);
};
if (has_kernel_def) {
// Gets the input/output names and their corresponding endpoint ranges.
NameRangeMap inp_names;
NameRangeMap out_names;
TF_RETURN_IF_ERROR(
NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
// Now that we know the size, fill with the default 'DEVICE_MEMORY'.
inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
// Fills in host memory types based on the kernel def.
const auto& from_proto = kdef->host_memory_arg();
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
if (!host_memory_args.empty()) {
return errors::InvalidArgument(
"HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
"' not found in OpDef: ", SummarizeOpDef(*op_def));
}
return Status::OK();
}
// Gets the input/output names and their corresponding endpoint ranges.
NameRangeMap inp_names;
NameRangeMap out_names;
TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
// Now that we know the size, fill with the default 'DEVICE_MEMORY'.
inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
// Fills in host memory types based on the kernel def.
const auto& from_proto = kdef->host_memory_arg();
std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
if (!host_memory_args.empty()) {
return errors::InvalidArgument(
"HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
"' not found in OpDef: ", SummarizeOpDef(*op_def));
} else {
// Set all the datatype to DEVICE_MEMORY by default, later on change it to
// HOST_MEMORY where it is required by the datatype.
inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY);
out_mtypes->resize(out_dtypes.size(), DEVICE_MEMORY);
}
CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
CHECK_LE(out_mtypes->size(), out_dtypes.size());
// Mark e.g. all resource and string types as host memory.
for (int i = 0; i < inp_mtypes->size(); ++i) {
if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
if (host_memory_required(inp_dtypes[i])) {
(*inp_mtypes)[i] = HOST_MEMORY;
}
}
for (int i = 0; i < out_mtypes->size(); ++i) {
if (DataTypeAlwaysOnHost(out_dtypes[i])) {
if (host_memory_required(out_dtypes[i])) {
(*out_mtypes)[i] = HOST_MEMORY;
}
}