From 292fd8cc9c67143dfe11be5454766881ff264680 Mon Sep 17 00:00:00 2001 From: George Karpenkov <cheshire@google.com> Date: Mon, 30 Mar 2020 19:25:28 -0700 Subject: [PATCH] Unify logic for deducing memory types for cases with and without kernel definition This is not an NFC change, since the new version also starts to respect `_input_hostmem` and `_output_hostmem` attributes for the function call case, which would be necessary for subsequent changes. PiperOrigin-RevId: 303873485 Change-Id: Ib366976c2121f0b5066236e52d1776d7a4535b4e --- tensorflow/core/framework/memory_types.cc | 77 ++++++++++------------- 1 file changed, 34 insertions(+), 43 deletions(-) diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 5393b162e80..d27ef1da61d 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -104,60 +104,51 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, return it != ndef.attr().end() && it->second.b(); }(); - // For functions (which have no KernelDef) and their gradients, we can only - // best-effort derive the memory type from the data type. For now, we assume - // int32 is always on host memory and other types are always on device memory. - // TODO(zhifengc,phawkins): We should do type inference over function bodies - // to derive the correct input/output memory types. We should also split - // host-memory and non host-memory arguments into separate type lists. - if (!status.ok() || IsFunctionCallOp(ndef.op())) { - if (device_type.type_string() == "TPU" || has_xla_compile) { - // Here we assume that if tf.function() is called within - // "with tf.device('/device:TPU:0')", the whole function will be compiled - // and executed on TPU. This is true today, but when we implement auto - // clustering on function body, this will no longer be true. For example, - // we might want to place string arguments on host. - for (const auto& t : inp_dtypes) - inp_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t)); - for (const auto& t : out_dtypes) - out_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t)); - } else { - for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t)); - for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t)); + bool has_kernel_def = status.ok() && !IsFunctionCallOp(ndef.op()); + auto host_memory_required = [&](const DataType& dt) { + bool int32_on_device = + has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile; + return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device); + }; + + if (has_kernel_def) { + // Gets the input/output names and their corresponding endpoint ranges. + NameRangeMap inp_names; + NameRangeMap out_names; + TF_RETURN_IF_ERROR( + NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); + + // Now that we know the size, fill with the default 'DEVICE_MEMORY'. + inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); + out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); + + // Fills in host memory types based on the kernel def. + const auto& from_proto = kdef->host_memory_arg(); + std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); + MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); + MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); + if (!host_memory_args.empty()) { + return errors::InvalidArgument( + "HostMemory args '", absl::StrJoin(host_memory_args, "', '"), + "' not found in OpDef: ", SummarizeOpDef(*op_def)); } - return Status::OK(); - } - - // Gets the input/output names and their corresponding endpoint ranges. - NameRangeMap inp_names; - NameRangeMap out_names; - TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names)); - - // Now that we know the size, fill with the default 'DEVICE_MEMORY'. - inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY); - out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY); - - // Fills in host memory types based on the kernel def. - const auto& from_proto = kdef->host_memory_arg(); - std::vector<string> host_memory_args(from_proto.begin(), from_proto.end()); - MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes); - MemoryTypesHelper(out_names, &host_memory_args, out_mtypes); - if (!host_memory_args.empty()) { - return errors::InvalidArgument( - "HostMemory args '", absl::StrJoin(host_memory_args, "', '"), - "' not found in OpDef: ", SummarizeOpDef(*op_def)); + } else { + // Set all the datatype to DEVICE_MEMORY by default, later on change it to + // HOST_MEMORY where it is required by the datatype. + inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY); + out_mtypes->resize(out_dtypes.size(), DEVICE_MEMORY); } CHECK_LE(inp_mtypes->size(), inp_dtypes.size()); CHECK_LE(out_mtypes->size(), out_dtypes.size()); // Mark e.g. all resource and string types as host memory. for (int i = 0; i < inp_mtypes->size(); ++i) { - if (DataTypeAlwaysOnHost(inp_dtypes[i])) { + if (host_memory_required(inp_dtypes[i])) { (*inp_mtypes)[i] = HOST_MEMORY; } } for (int i = 0; i < out_mtypes->size(); ++i) { - if (DataTypeAlwaysOnHost(out_dtypes[i])) { + if (host_memory_required(out_dtypes[i])) { (*out_mtypes)[i] = HOST_MEMORY; } }