diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index f6d376e01f4..b3655dcba63 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -435,6 +435,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/common_runtime:core_cpu_internal",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc
index 21c3194eadc..08b3bea1084 100644
--- a/tensorflow/compiler/jit/get_compiler_ir.cc
+++ b/tensorflow/compiler/jit/get_compiler_ir.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_platform_info.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
 
 xla::StatusOr<std::string> GetCompilerIr(
     IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
-    absl::string_view func_name, Device* dev,
-    absl::Span<const Tensor* const> inputs) {
+    absl::string_view func_name, Device* dev, EagerContext* context,
+    absl::Span<const TensorHandle* const> inputs_handles) {
   NameAttrList function;
   function.set_name(std::string{func_name});
 
@@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
       GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
   MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
 
+  std::deque<Tensor> inputs_storage;
+  std::vector<const Tensor*> inputs;
+  inputs.reserve(inputs_handles.size());
+  for (int i = 0; i < inputs_handles.size(); i++) {
+    const TensorHandle* th = inputs_handles[i];
+    const Tensor* t;
+    // Handle owns the tensor.
+    TF_RETURN_IF_ERROR(th->Tensor(&t));
+    if (absl::c_binary_search(constant_arg_indices, i)) {
+      // Need to make sure it's on the host.
+      inputs_storage.emplace_back(t->dtype(), t->shape());
+      TF_RETURN_IF_ERROR(
+          th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
+      inputs.push_back(&inputs_storage.back());
+    } else {
+      inputs.push_back(t);
+    }
+  }
+
   std::vector<VariableInfo> variable_infos;
   TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
       rmgr, dev, inputs, resource_arg_indices, &variable_infos));
diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h
index 0ec69fd9b38..0a0a1a44271 100644
--- a/tensorflow/compiler/jit/get_compiler_ir.h
+++ b/tensorflow/compiler/jit/get_compiler_ir.h
@@ -24,6 +24,8 @@ namespace tensorflow {
 class ProcessFunctionLibraryRuntime;
 class Device;
 class Tensor;
+class TensorHandle;
+class EagerContext;
 
 enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
 
@@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
 // Returns the IR format of the selected stage for a given function `func_name`
 // using library runtime `runtime` on a device `dev` with given `inputs`.
 xla::StatusOr<std::string> GetCompilerIr(
     IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
-    absl::string_view func_name, Device* dev,
-    absl::Span<const Tensor* const> inputs);
+    absl::string_view func_name, Device* dev, EagerContext* context,
+    absl::Span<const TensorHandle* const> inputs);
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 306943a7522..16d4875f03c 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -643,6 +643,19 @@ class DefFunctionTest(xla_test.XLATestCase):
 
     self.assertIn('tuple', f.experimental_get_compiler_ir(l)())
 
+  def testConstantOnWrongDevice(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      s = math_ops.cast(random_ops.random_normal([2]), dtypes.int32)
+      l = random_ops.random_normal([s[0] * s[1]])
+
+      @def_function.function(experimental_compile=True)
+      def f(l):
+        return array_ops.reshape(l, s)
+
+      self.assertIn('tuple',
+                    f.experimental_get_compiler_ir(l)())
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index fb89f3f54be..36165deeaad 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -314,19 +314,10 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
 
   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
 
-  std::vector<const Tensor*> input_tensors;
+  std::vector<const TensorHandle*> input_handles;
   for (TFE_TensorHandle* tensor_handle : handles) {
     AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
-    TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
-
-    const Tensor* t;
-    Status st = th->Tensor(&t);
-    if (!st.ok()) {
-      ThrowValueError(
-          absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
-              .c_str());
-    }
-    input_tensors.push_back(t);
+    input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
   }
 
   DeviceNameUtils::ParsedName input_device_name;
@@ -347,7 +338,7 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
 
   xla::StatusOr<std::string> hlo_text =
       GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
-                    *selected_device, input_tensors);
+                    *selected_device, context, input_handles);
   if (!hlo_text.ok()) {
     ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
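Context for the patch: the wrapper previously resolved every TFE_TensorHandle to its backing Tensor before calling GetCompilerIr, so a compile-time-constant argument whose tensor lives on an accelerator reached const-analysis as a device-resident buffer it could not read. Passing TensorHandles through lets GetCompilerIr copy just the constant arguments to the host (/*d=*/nullptr in CopyToDevice requests the host CPU) while all other arguments are used in place; the host copies live in a std::deque so the const Tensor* pointers into inputs_storage stay valid as it grows. Below is a minimal sketch of the user-facing scenario this fixes, assuming eager mode and an available GPU; the device string, shapes, and values are illustrative, not taken from the patch:

    import tensorflow as tf

    with tf.device('GPU:0'):
      # `s` is captured by `f` and feeds reshape's shape argument, so XLA
      # needs it as a compile-time constant, yet its tensor is on the GPU.
      s = tf.constant([2, 3])
      l = tf.random.normal([6])

      @tf.function(experimental_compile=True)
      def f(l):
        return tf.reshape(l, s)

      # Before the patch, resolving the constant argument here could fail;
      # with the host copy in GetCompilerIr this returns the HLO text.
      print(f.experimental_get_compiler_ir(l)())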