From 292fd8cc9c67143dfe11be5454766881ff264680 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Mon, 30 Mar 2020 19:25:28 -0700
Subject: [PATCH] Unify logic for deducing memory types for cases with and
 without kernel definition

This is not an NFC change, since the new version also starts to respect
`_input_hostmem` and `_output_hostmem` attributes for the function call case,
which would be necessary for subsequent changes.

PiperOrigin-RevId: 303873485
Change-Id: Ib366976c2121f0b5066236e52d1776d7a4535b4e
---
 tensorflow/core/framework/memory_types.cc | 77 ++++++++++-------------
 1 file changed, 34 insertions(+), 43 deletions(-)

diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 5393b162e80..d27ef1da61d 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -104,60 +104,51 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
     return it != ndef.attr().end() && it->second.b();
   }();
 
-  // For functions (which have no KernelDef) and their gradients, we can only
-  // best-effort derive the memory type from the data type. For now, we assume
-  // int32 is always on host memory and other types are always on device memory.
-  // TODO(zhifengc,phawkins): We should do type inference over function bodies
-  // to derive the correct input/output memory types. We should also split
-  // host-memory and non host-memory arguments into separate type lists.
-  if (!status.ok() || IsFunctionCallOp(ndef.op())) {
-    if (device_type.type_string() == "TPU" || has_xla_compile) {
-      // Here we assume that if tf.function() is called within
-      // "with tf.device('/device:TPU:0')", the whole function will be compiled
-      // and executed on TPU. This is true today, but when we implement auto
-      // clustering on function body, this will no longer be true. For example,
-      // we might want to place string arguments on host.
-      for (const auto& t : inp_dtypes)
-        inp_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t));
-      for (const auto& t : out_dtypes)
-        out_mtypes->push_back(MTypeFromDTypeIntsOnDevice(t));
-    } else {
-      for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
-      for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
+  bool has_kernel_def = status.ok() && !IsFunctionCallOp(ndef.op());
+  auto host_memory_required = [&](const DataType& dt) {
+    bool int32_on_device =
+        has_kernel_def || device_type.type_string() == "TPU" || has_xla_compile;
+    return DataTypeAlwaysOnHost(dt) || (dt == DT_INT32 && !int32_on_device);
+  };
+
+  if (has_kernel_def) {
+    // Gets the input/output names and their corresponding endpoint ranges.
+    NameRangeMap inp_names;
+    NameRangeMap out_names;
+    TF_RETURN_IF_ERROR(
+        NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
+
+    // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
+    inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
+    out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
+
+    // Fills in host memory types based on the kernel def.
+    const auto& from_proto = kdef->host_memory_arg();
+    std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
+    MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
+    MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
+    if (!host_memory_args.empty()) {
+      return errors::InvalidArgument(
+          "HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
+          "' not found in OpDef: ", SummarizeOpDef(*op_def));
     }
-    return Status::OK();
-  }
-
-  // Gets the input/output names and their corresponding endpoint ranges.
-  NameRangeMap inp_names;
-  NameRangeMap out_names;
-  TF_RETURN_IF_ERROR(NameRangesForNode(ndef, *op_def, &inp_names, &out_names));
-
-  // Now that we know the size, fill with the default 'DEVICE_MEMORY'.
-  inp_mtypes->resize(GetTotal(inp_names), DEVICE_MEMORY);
-  out_mtypes->resize(GetTotal(out_names), DEVICE_MEMORY);
-
-  // Fills in host memory types based on the kernel def.
-  const auto& from_proto = kdef->host_memory_arg();
-  std::vector<string> host_memory_args(from_proto.begin(), from_proto.end());
-  MemoryTypesHelper(inp_names, &host_memory_args, inp_mtypes);
-  MemoryTypesHelper(out_names, &host_memory_args, out_mtypes);
-  if (!host_memory_args.empty()) {
-    return errors::InvalidArgument(
-        "HostMemory args '", absl::StrJoin(host_memory_args, "', '"),
-        "' not found in OpDef: ", SummarizeOpDef(*op_def));
+  } else {
+    // Set all the datatype to DEVICE_MEMORY by default, later on change it to
+    // HOST_MEMORY where it is required by the datatype.
+    inp_mtypes->resize(inp_dtypes.size(), DEVICE_MEMORY);
+    out_mtypes->resize(out_dtypes.size(), DEVICE_MEMORY);
   }
   CHECK_LE(inp_mtypes->size(), inp_dtypes.size());
   CHECK_LE(out_mtypes->size(), out_dtypes.size());
 
   // Mark e.g. all resource and string types as host memory.
   for (int i = 0; i < inp_mtypes->size(); ++i) {
-    if (DataTypeAlwaysOnHost(inp_dtypes[i])) {
+    if (host_memory_required(inp_dtypes[i])) {
       (*inp_mtypes)[i] = HOST_MEMORY;
     }
   }
   for (int i = 0; i < out_mtypes->size(); ++i) {
-    if (DataTypeAlwaysOnHost(out_dtypes[i])) {
+    if (host_memory_required(out_dtypes[i])) {
       (*out_mtypes)[i] = HOST_MEMORY;
     }
   }