[TF2XLA] Explicitly copy must-be-constants to host in get_compiler_ir
PiperOrigin-RevId: 334878548 Change-Id: Icfb2a4806d35c4974887aa19fc55b382a8a7ff5b
This commit is contained in:
parent
55cadd046b
commit
8991d372bf
@ -435,6 +435,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/jit/xla_platform_info.h"
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
@ -47,8 +48,8 @@ static xla::StatusOr<xla::LocalExecutable*> GetLocalExecutable(
|
||||
|
||||
xla::StatusOr<std::string> GetCompilerIr(
|
||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||
absl::string_view func_name, Device* dev,
|
||||
absl::Span<const Tensor* const> inputs) {
|
||||
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||
absl::Span<const TensorHandle* const> inputs_handles) {
|
||||
NameAttrList function;
|
||||
function.set_name(std::string{func_name});
|
||||
|
||||
@ -65,6 +66,25 @@ xla::StatusOr<std::string> GetCompilerIr(
|
||||
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
|
||||
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);
|
||||
|
||||
std::deque<Tensor> inputs_storage;
|
||||
std::vector<const Tensor*> inputs;
|
||||
inputs.reserve(inputs_handles.size());
|
||||
for (int i = 0; i < inputs_handles.size(); i++) {
|
||||
const TensorHandle* th = inputs_handles[i];
|
||||
const Tensor* t;
|
||||
// Handle owns the tensor.
|
||||
TF_RETURN_IF_ERROR(th->Tensor(&t));
|
||||
if (absl::c_binary_search(constant_arg_indices, i)) {
|
||||
// Need to make sure it's on the host.
|
||||
inputs_storage.emplace_back(t->dtype(), t->shape());
|
||||
TF_RETURN_IF_ERROR(
|
||||
th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
|
||||
inputs.push_back(&inputs_storage.back());
|
||||
} else {
|
||||
inputs.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
|
||||
rmgr, dev, inputs, resource_arg_indices, &variable_infos));
|
||||
|
@ -24,6 +24,8 @@ namespace tensorflow {
|
||||
class ProcessFunctionLibraryRuntime;
|
||||
class Device;
|
||||
class Tensor;
|
||||
class TensorHandle;
|
||||
class EagerContext;
|
||||
|
||||
enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
||||
|
||||
@ -31,8 +33,8 @@ enum class IrExportStage { HLO, OPTIMIZED_HLO, OPTIMIZED_HLO_DOT };
|
||||
// `runtime` on a device `dev` with given `inputs`.
|
||||
xla::StatusOr<std::string> GetCompilerIr(
|
||||
IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
|
||||
absl::string_view func_name, Device* dev,
|
||||
absl::Span<const Tensor* const> inputs);
|
||||
absl::string_view func_name, Device* dev, EagerContext* context,
|
||||
absl::Span<const TensorHandle* const> inputs);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -643,6 +643,19 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
self.assertIn('tuple',
|
||||
f.experimental_get_compiler_ir(l)())
|
||||
|
||||
def testConstantOnWrongDevice(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
s = math_ops.cast(random_ops.random_normal([2]), dtypes.int32)
|
||||
l = random_ops.random_normal([s[0] * s[1]])
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def f(l):
|
||||
return array_ops.reshape(l, s)
|
||||
|
||||
self.assertIn('tuple',
|
||||
f.experimental_get_compiler_ir(l)())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -314,19 +314,10 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
|
||||
|
||||
TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
|
||||
|
||||
std::vector<const Tensor*> input_tensors;
|
||||
std::vector<const TensorHandle*> input_handles;
|
||||
for (TFE_TensorHandle* tensor_handle : handles) {
|
||||
AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
|
||||
TensorHandle* th = TensorHandleFromInterface(abstract_tensor_handle);
|
||||
|
||||
const Tensor* t;
|
||||
Status st = th->Tensor(&t);
|
||||
if (!st.ok()) {
|
||||
ThrowValueError(
|
||||
absl::StrFormat("Could not resolve tensor: '%s'", st.error_message())
|
||||
.c_str());
|
||||
}
|
||||
input_tensors.push_back(t);
|
||||
input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName input_device_name;
|
||||
@ -347,7 +338,7 @@ static std::string TFE_GetCompilerIr(py::handle& ctx,
|
||||
|
||||
xla::StatusOr<std::string> hlo_text =
|
||||
GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
|
||||
*selected_device, input_tensors);
|
||||
*selected_device, context, input_handles);
|
||||
|
||||
if (!hlo_text.ok()) {
|
||||
ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
|
||||
|
Loading…
x
Reference in New Issue
Block a user