[TF2XLA] Explicitly copy must-be-constants to host in get_compiler_ir

PiperOrigin-RevId: 334878548
Change-Id: Icfb2a4806d35c4974887aa19fc55b382a8a7ff5b
Authored by George Karpenkov on 2020-10-01 12:56:35 -07:00; committed by TensorFlower Gardener
parent 55cadd046b
commit 8991d372bf
5 changed files with 43 additions and 16 deletions
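
The heart of the change is in the GetCompilerIr hunks below: the function now takes eager TensorHandles (plus the EagerContext) instead of already-resolved Tensors, and every argument the compiler treats as a must-be-constant is copied to host memory before compilation. A minimal sketch of that pattern, assuming the same TensorHandle API used in the diff; the helper name and its standalone packaging are illustrative, not part of the commit:

// Sketch only (not part of the commit): the host-copy pattern implemented by
// the loop added to GetCompilerIr. TensorHandle::Tensor and
// TensorHandle::CopyToDevice are the same calls used in the hunk below;
// CopyToDevice with d == nullptr produces a host copy.
#include <deque>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// `constant_arg_indices` must be sorted, as in the real code, so that
// absl::c_binary_search can be used. "ResolveInputsForCompilation" is a
// hypothetical helper name.
Status ResolveInputsForCompilation(
    EagerContext* context, absl::Span<const TensorHandle* const> handles,
    absl::Span<const int> constant_arg_indices,
    std::deque<Tensor>* host_storage,  // keeps host copies alive
    std::vector<const Tensor*>* inputs) {
  inputs->reserve(handles.size());
  for (int i = 0; i < handles.size(); ++i) {
    const TensorHandle* th = handles[i];
    const Tensor* t;
    TF_RETURN_IF_ERROR(th->Tensor(&t));  // The handle owns the tensor.
    if (absl::c_binary_search(constant_arg_indices, i)) {
      // Must-be-constant arguments may live on an accelerator; constant
      // folding during compilation reads them on the host, so copy them over.
      host_storage->emplace_back(t->dtype(), t->shape());
      TF_RETURN_IF_ERROR(
          th->CopyToDevice(*context, /*d=*/nullptr, &host_storage->back()));
      inputs->push_back(&host_storage->back());
    } else {
      inputs->push_back(t);
    }
  }
  return Status::OK();
}

}  // namespace tensorflow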


@@ -435,6 +435,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/common_runtime:core_cpu_internal",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",


@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
 xla::StatusOr<std::string> GetCompilerIr(
     IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
-    absl::string_view func_name, Device* dev,
-    absl::Span<const Tensor* const> inputs) {
+    absl::string_view func_name, Device* dev, EagerContext* context,
+    absl::Span<const TensorHandle* const> inputs_handles) {
   NameAttrList function;
   function.set_name(std::string{func_name});
@@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
 
+  std::deque<Tensor> inputs_storage;
+  std::vector<const Tensor*> inputs;
+  inputs.reserve(inputs_handles.size());
+  for (int i = 0; i < inputs_handles.size(); i++) {
+    const TensorHandle* th = inputs_handles[i];
+    const Tensor* t;
+    // Handle owns the tensor.
+    TF_RETURN_IF_ERROR(th->Tensor(&t));
+    if (absl::c_binary_search(constant_arg_indices, i)) {
+      // Need to make sure it's on the host.
+      inputs_storage.emplace_back(t->dtype(), t->shape());
+      TF_RETURN_IF_ERROR(
+          th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
+      inputs.push_back(&inputs_storage.back());
+    } else {
+      inputs.push_back(t);
+    }
+  }
+
   std::vector<VariableInfo> variable_infos;
   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
       rmgr, dev, inputs, resource_arg_indices, &variable_infos));


@@ -24,6 +24,8 @@ namespace tensorflow {
 class ProcessFunctionLibraryRuntime;
 class Device;
 class Tensor;
+class TensorHandle;
+class EagerContext;
 
 enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
@@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
 // `runtime` on a device `dev` with given `inputs`.
 xla::StatusOr<std::string> GetCompilerIr(
     IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
-    absl::string_view func_name, Device* dev,
-    absl::Span<const Tensor* const> inputs);
+    absl::string_view func_name, Device* dev, EagerContext* context,
+    absl::Span<const TensorHandle* const> inputs);
 
 }  // namespace tensorflow
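
For orientation, a hedged sketch of a call against the updated declaration above; the wrapper name and the sources of its parameters are assumptions, it presumes the header declaring GetCompilerIr is included, and the real call site changed by this commit appears in the final hunk of this page:

// Sketch only (not part of the commit): "DumpOptimizedHloFor" is a
// hypothetical wrapper that forwards caller-supplied runtime state and the
// already-collected eager TensorHandles to the new overload.
xla::StatusOr<std::string> DumpOptimizedHloFor(
    ProcessFunctionLibraryRuntime* pflr, absl::string_view concrete_func_name,
    Device* device, EagerContext* context,
    const std::vector<const TensorHandle*>& handles) {
  return GetCompilerIr(IrExportStage::OPTIMIZED_HLO, pflr, concrete_func_name,
                       device, context, handles);
}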


@@ -643,6 +643,19 @@ class DefFunctionTest(xla_test.XLATestCase):
       self.assertIn('tuple',
                     f.experimental_get_compiler_ir(l)())
 
+  def testConstantOnWrongDevice(self):
+    with ops.device('device:{}:0'.format(self.device)):
+      s = math_ops.cast(random_ops.random_normal([2]), dtypes.int32)
+      l = random_ops.random_normal([s[0] * s[1]])
+
+      @def_function.function(experimental_compile=True)
+      def f(l):
+        return array_ops.reshape(l, s)
+
+      self.assertIn('tuple',
+                    f.experimental_get_compiler_ir(l)())
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()


@@ -314,19 +314,10 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
 
-  std::vector<const Tensor*> input_tensors;
+  std::vector<const TensorHandle*> input_handles;
   for (TFE_TensorHandle* tensor_handle : handles) {
     AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
-    TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
-    const Tensor* t;
-    Status st = th->Tensor(&t);
-    if (!st.ok()) {
-      ThrowValueError(
-          absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
-              .c_str());
-    }
-    input_tensors.push_back(t);
+    input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
   }
 
   DeviceNameUtils::ParsedName input_device_name;
@@ -347,7 +338,7 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
   xla::StatusOr<std::string> hlo_text =
       GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
-                    *selected_device, input_tensors);
+                    *selected_device, context, input_handles);
 
   if (!hlo_text.ok()) {
     ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",