From 7f2d8106f5d1dafde01a02c5bc442f5e9e20357f Mon Sep 17 00:00:00 2001
From: Fabio Riccardi <fricc@google.com>
Date: Tue, 26 May 2020 13:58:36 -0700
Subject: [PATCH] Introduce Vulkan API with integration tests.

PiperOrigin-RevId: 313262552
Change-Id: I7d56bba03b7752938bd5f0cf5b08315941118369
---
 tensorflow/lite/delegates/gpu/api.cc | 12 ++++++++++
 tensorflow/lite/delegates/gpu/api.h  | 33 ++++++++++++++++++++++++++--
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/delegates/gpu/api.cc b/tensorflow/lite/delegates/gpu/api.cc
index 6c299e4965c..1a18fcb87f2 100644
--- a/tensorflow/lite/delegates/gpu/api.cc
+++ b/tensorflow/lite/delegates/gpu/api.cc
@@ -31,6 +31,12 @@ struct ObjectTypeGetter {
   ObjectType operator()(OpenClTexture) const {
     return ObjectType::OPENCL_TEXTURE;
   }
+  ObjectType operator()(VulkanBuffer) const {
+    return ObjectType::VULKAN_BUFFER;
+  }
+  ObjectType operator()(VulkanTexture) const {
+    return ObjectType::VULKAN_TEXTURE;
+  }
   ObjectType operator()(CpuMemory) const { return ObjectType::CPU_MEMORY; }
 };
 
@@ -42,6 +48,8 @@ struct ObjectValidityChecker {
   }
   bool operator()(OpenClBuffer obj) const { return obj.memobj; }
   bool operator()(OpenClTexture obj) const { return obj.memobj; }
+  bool operator()(VulkanBuffer obj) const { return obj.memory; }
+  bool operator()(VulkanTexture obj) const { return obj.memory; }
   bool operator()(CpuMemory obj) const {
     return obj.data != nullptr && obj.size_bytes > 0 &&
            (data_type == DataType::UNKNOWN ||
@@ -81,6 +89,10 @@ bool IsObjectPresent(ObjectType type, const TensorObject& obj) {
       return absl::get_if<OpenClBuffer>(&obj);
     case ObjectType::OPENCL_TEXTURE:
       return absl::get_if<OpenClTexture>(&obj);
+    case ObjectType::VULKAN_BUFFER:
+      return absl::get_if<VulkanBuffer>(&obj);
+    case ObjectType::VULKAN_TEXTURE:
+      return absl::get_if<VulkanTexture>(&obj);
     case ObjectType::UNKNOWN:
       return false;
   }
diff --git a/tensorflow/lite/delegates/gpu/api.h b/tensorflow/lite/delegates/gpu/api.h
index 2a531f1f81b..1dfeeebd700 100644
--- a/tensorflow/lite/delegates/gpu/api.h
+++ b/tensorflow/lite/delegates/gpu/api.h
@@ -71,6 +71,8 @@ enum class ObjectType {
   CPU_MEMORY,
   OPENCL_TEXTURE,
   OPENCL_BUFFER,
+  VULKAN_BUFFER,
+  VULKAN_TEXTURE
 };
 
 struct OpenGlBuffer {
@@ -104,11 +106,37 @@ struct OpenClTexture {
   // TODO(akulik): should it specify texture format?
 };
 
+struct VulkanBuffer {
+  VulkanBuffer() = default;
+  explicit VulkanBuffer(VkBuffer buffer_, VkDeviceSize size_,
+                        VkDeviceMemory memory_, VkDeviceSize offset_)
+      : buffer(buffer_), size(size_), memory(memory_), offset(offset_) {}
+
+  VkBuffer buffer;
+  VkDeviceSize size;
+  VkDeviceMemory memory;
+  VkDeviceSize offset;
+};
+
+struct VulkanTexture {
+  VulkanTexture() = default;
+  explicit VulkanTexture(VkDeviceMemory new_memory) : memory(new_memory) {}
+
+  VkImage image;
+  VkImageView image_view;
+  VkFormat format;
+  VkExtent3D extent;
+  VkDeviceMemory memory;
+  VkDeviceSize offset;
+};
+
 struct VulkanMemory {
   VulkanMemory() = default;
   explicit VulkanMemory(VkDeviceMemory new_memory) : memory(new_memory) {}
 
   VkDeviceMemory memory;
+  VkDeviceSize size;
+  VkDeviceSize offset;
 };
 
 struct CpuMemory {
@@ -195,8 +223,9 @@ bool IsValid(const TensorObjectDef& def);
 // @return the number of elements in a tensor object.
 uint32_t NumElements(const TensorObjectDef& def);
 
-using TensorObject = absl::variant<absl::monostate, OpenGlBuffer, OpenGlTexture,
-                                   CpuMemory, OpenClBuffer, OpenClTexture>;
+using TensorObject =
+    absl::variant<absl::monostate, OpenGlBuffer, OpenGlTexture, CpuMemory,
+                  OpenClBuffer, OpenClTexture, VulkanBuffer, VulkanTexture>;
 
 // @return true if object is set and corresponding values are defined.
 bool IsValid(const TensorObjectDef& def, const TensorObject& object);