STT-tensorflow/tensorflow/c/eager/custom_device_testutil.h
Allen Lavoie a764f3ab76 Give custom devices the option to do type-based dispatch for ops with no explicit placement
When an op has a custom device input and one or more physical device inputs, the op is presented to the custom device, with an indication (via the device property of the passed op) that the user did not explicitly request the placement.

Custom devices which want to stick to strict scope-based placement can either copy off the inputs and run the op on the default device or throw an error. The parallel device will stick with scope-only dispatch for now.

PiperOrigin-RevId: 328840123
Change-Id: Ic7490c0700a7ca5c74fd362211fa2fc9e008051c
2020-08-27 16:39:24 -07:00
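
The placement rule described in the commit message can be sketched as a small decision function. This is a minimal, self-contained model and not TensorFlow code: `OpInfo`, `DevicePolicy`, `Placement`, `Decide`, and the `error_on_implicit` flag are hypothetical names introduced for illustration. It also ignores the case of an op explicitly placed on a physical device, which the commit message does not cover.

```cpp
// Hypothetical model of the placement decision; none of these names are
// part of the TensorFlow API.
enum class Placement {
  kCustomDevice,      // the custom device executes the op
  kCopyOffToDefault,  // inputs are copied off and the default device runs the op
  kError,             // the custom device rejects the op
  kDefaultDevice,     // the custom device is never offered the op
};

struct OpInfo {
  bool explicitly_placed_on_custom;  // the user requested the custom device
  bool has_custom_device_input;      // at least one input lives on the custom device
};

struct DevicePolicy {
  bool strict_scope_placement;  // accept only explicitly placed ops
  bool error_on_implicit;       // assumption: error instead of copying off
};

constexpr Placement Decide(const OpInfo& op, const DevicePolicy& policy) {
  // Explicit (scope-based) placement always goes to the custom device.
  if (op.explicitly_placed_on_custom) return Placement::kCustomDevice;
  // With no custom device input, the op is not presented to the custom device.
  if (!op.has_custom_device_input) return Placement::kDefaultDevice;
  // The op is presented to the custom device without an explicit placement.
  // A device that allows type-based dispatch simply runs it.
  if (!policy.strict_scope_placement) return Placement::kCustomDevice;
  // A strict device either rejects the op or copies the inputs off.
  return policy.error_on_implicit ? Placement::kError
                                  : Placement::kCopyOffToDefault;
}
```

For example, `Decide(OpInfo{false, true}, DevicePolicy{false, false})` yields `Placement::kCustomDevice`, which corresponds to type-based dispatch. With `strict_scope_placement` set, the same op is either copied off to the default device or rejected.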

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_status.h"
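// Creates a logging device called `name` and registers it with `context`.
// `*arrived_flag` is set when a tensor is copied onto the device and
// `*executed_flag` when an op executes on it. If `strict_scope_placement` is
// true, the device accepts only explicitly placed ops; otherwise it uses
// type-based dispatch.
void RegisterLoggingDevice(TFE_Context* context, const char* name,
                           bool strict_scope_placement, bool* arrived_flag,
                           bool* executed_flag, TF_Status* status);
// Allocates a logging device and its `device_info` without registering them.
void AllocateLoggingDevice(const char* name, bool* arrived_flag,
                           bool* executed_flag, TFE_CustomDevice** device,
                           void** device_info);
// Returns the tensor handle wrapped by `logged_tensor_handle`.
TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
                                     TF_Status* status);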
#endif // TENSORFLOW_C_EAGER_CUSTOM_DEVICE_TESTUTIL_H_