pyo -> Pyo pyo_or_throw -> PyoOrThrow PiperOrigin-RevId: 306876916 Change-Id: Idf846a2b13f93ab504ed277e229f473cf5a8605a
		
			
				
	
	
		
			78 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			78 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
						|
 | 
						|
Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
you may not use this file except in compliance with the License.
 | 
						|
You may obtain a copy of the License at
 | 
						|
 | 
						|
    http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
 | 
						|
Unless required by applicable law or agreed to in writing, software
 | 
						|
distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
See the License for the specific language governing permissions and
 | 
						|
limitations under the License.
 | 
						|
==============================================================================*/
 | 
						|
 | 
						|
#include "tensorflow/c/eager/custom_device_testutil.h"
 | 
						|
 | 
						|
#include "Python.h"
 | 
						|
#include "pybind11/pybind11.h"
 | 
						|
#include "pybind11/stl.h"
 | 
						|
#include "tensorflow/c/c_api.h"
 | 
						|
#include "tensorflow/c/c_api_experimental.h"
 | 
						|
#include "tensorflow/c/eager/c_api.h"
 | 
						|
#include "tensorflow/c/eager/c_api_experimental.h"
 | 
						|
#include "tensorflow/c/tf_status.h"
 | 
						|
#include "tensorflow/c/tf_status_helper.h"
 | 
						|
#include "tensorflow/python/lib/core/py_exception_registry.h"
 | 
						|
#include "tensorflow/python/lib/core/pybind11_lib.h"
 | 
						|
#include "tensorflow/python/lib/core/pybind11_status.h"
 | 
						|
#include "tensorflow/python/lib/core/safe_ptr.h"
 | 
						|
#include "tensorflow/python/util/util.h"
 | 
						|
 | 
						|
namespace py = pybind11;
 | 
						|
 | 
						|
void CallDelete_Flag(PyObject* capsule) {
 | 
						|
  delete reinterpret_cast<bool*>(PyCapsule_GetPointer(capsule, "flag"));
 | 
						|
}
 | 
						|
 | 
						|
void CallDelete_Device(PyObject* capsule) {
 | 
						|
  delete reinterpret_cast<TFE_CustomDevice*>(
 | 
						|
      PyCapsule_GetPointer(capsule, "TFE_CustomDevice"));
 | 
						|
}
 | 
						|
 | 
						|
void CallDelete_DeviceInfo(PyObject* capsule) {
 | 
						|
  PyErr_SetString(PyExc_AssertionError,
 | 
						|
                  "Capsule should be consumed by TFE_Py_RegisterCustomDevice");
 | 
						|
}
 | 
						|
 | 
						|
PYBIND11_MODULE(custom_device_testutil, m) {
 | 
						|
  m.def("GetLoggingDeviceCapsules", [](const char* name) {
 | 
						|
    bool* arrived_flag = new bool;
 | 
						|
    bool* executed_flag = new bool;
 | 
						|
    *arrived_flag = false;
 | 
						|
    *executed_flag = false;
 | 
						|
    tensorflow::Safe_PyObjectPtr arrived_capsule(
 | 
						|
        PyCapsule_New(arrived_flag, "flag", &CallDelete_Flag));
 | 
						|
    tensorflow::Safe_PyObjectPtr executed_capsule(
 | 
						|
        PyCapsule_New(executed_flag, "flag", &CallDelete_Flag));
 | 
						|
    TFE_CustomDevice* device;
 | 
						|
    void* device_info;
 | 
						|
    AllocateLoggingDevice(name, arrived_flag, executed_flag, &device,
 | 
						|
                          &device_info);
 | 
						|
    tensorflow::Safe_PyObjectPtr device_capsule(
 | 
						|
        PyCapsule_New(device, "TFE_CustomDevice", &CallDelete_Device));
 | 
						|
    tensorflow::Safe_PyObjectPtr device_info_capsule(PyCapsule_New(
 | 
						|
        device_info, "TFE_CustomDevice_DeviceInfo", &CallDelete_DeviceInfo));
 | 
						|
    return tensorflow::PyoOrThrow(
 | 
						|
        PyTuple_Pack(4, device_capsule.get(), device_info_capsule.get(),
 | 
						|
                     arrived_capsule.get(), executed_capsule.get()));
 | 
						|
  });
 | 
						|
  m.def("FlagValue", [](py::capsule flag_capsule) {
 | 
						|
    bool* flag = reinterpret_cast<bool*>(
 | 
						|
        PyCapsule_GetPointer(flag_capsule.ptr(), "flag"));
 | 
						|
    if (PyErr_Occurred()) throw py::error_already_set();
 | 
						|
    return *flag;
 | 
						|
  });
 | 
						|
}
 |