Fix Physical Device Queue selector

This commit is contained in:
Florian RICHER 2025-05-21 13:57:14 +02:00
parent 422b2825ad
commit a9cf0cf4cf
5 changed files with 106 additions and 74 deletions

View file

@ -5,6 +5,8 @@ import org.lwjgl.system.MemoryStack;
import org.lwjgl.vulkan.*; import org.lwjgl.vulkan.*;
import org.tinylog.Logger; import org.tinylog.Logger;
import fr.mrdev023.vulkan_java.vk.utils.SuitablePhysicalDeviceFinder;
import java.nio.*; import java.nio.*;
import java.util.*; import java.util.*;
@ -15,9 +17,12 @@ public class Device {
private final PhysicalDevice physicalDevice; private final PhysicalDevice physicalDevice;
private final VkDevice vkDevice; private final VkDevice vkDevice;
private final Queue.GraphicsQueue graphicsQueue;
private final Queue.ComputeQueue computeQueue;
private final Queue.TransferQueue transferQueue;
public Device(PhysicalDevice physicalDevice) throws VulkanError { public Device(SuitablePhysicalDeviceFinder.MatchResult physicalDeviceMatch) throws VulkanError {
this.physicalDevice = physicalDevice; this.physicalDevice = physicalDeviceMatch.physicalDevice;
try (MemoryStack stack = MemoryStack.stackPush()) { try (MemoryStack stack = MemoryStack.stackPush()) {
// Define required extensions // Define required extensions
@ -60,6 +65,10 @@ public class Device {
vkCheck(vkCreateDevice(physicalDevice.getVkPhysicalDevice(), deviceCreateInfo, null, pp), vkCheck(vkCreateDevice(physicalDevice.getVkPhysicalDevice(), deviceCreateInfo, null, pp),
"Failed to create device"); "Failed to create device");
vkDevice = new VkDevice(pp.get(0), physicalDevice.getVkPhysicalDevice(), deviceCreateInfo); vkDevice = new VkDevice(pp.get(0), physicalDevice.getVkPhysicalDevice(), deviceCreateInfo);
graphicsQueue = new Queue.GraphicsQueue(this, physicalDeviceMatch.graphicsQueueFamilyIndex, 0);
computeQueue = new Queue.ComputeQueue(this, physicalDeviceMatch.computeQueueFamilyIndex, 1);
transferQueue = new Queue.TransferQueue(this, physicalDeviceMatch.transferQueueFamilyIndex, 2);
} }
Logger.debug("Vulkan device created"); Logger.debug("Vulkan device created");
@ -111,4 +120,16 @@ public class Device {
public void waitIdle() { public void waitIdle() {
vkDeviceWaitIdle(vkDevice); vkDeviceWaitIdle(vkDevice);
} }
public Queue.GraphicsQueue getGraphicsQueue() {
return graphicsQueue;
}
public Queue.ComputeQueue getComputeQueue() {
return computeQueue;
}
public Queue.TransferQueue getTransferQueue() {
return transferQueue;
}
} }

View file

@ -31,29 +31,20 @@ public class Queue {
} }
public static class GraphicsQueue extends Queue { public static class GraphicsQueue extends Queue {
public GraphicsQueue(Device device, int queueFamilyIndex, int queueIndex) {
public GraphicsQueue(Device device, int queueIndex) { super(device, queueFamilyIndex, queueIndex);
super(device, getGraphicsQueueFamilyIndex(device), queueIndex);
} }
}
private static int getGraphicsQueueFamilyIndex(Device device) { public static class ComputeQueue extends Queue {
int index = -1; public ComputeQueue(Device device, int queueFamilyIndex, int queueIndex) {
PhysicalDevice physicalDevice = device.getPhysicalDevice(); super(device, queueFamilyIndex, queueIndex);
VkQueueFamilyProperties.Buffer queuePropsBuff = physicalDevice.getVkQueueFamilyProps(); }
int numQueuesFamilies = queuePropsBuff.capacity(); }
for (int i = 0; i < numQueuesFamilies; i++) {
VkQueueFamilyProperties props = queuePropsBuff.get(i);
boolean graphicsQueue = (props.queueFlags() & VK_QUEUE_GRAPHICS_BIT) != 0;
if (graphicsQueue) {
index = i;
break;
}
}
if (index < 0) { public static class TransferQueue extends Queue {
throw new RuntimeException("Failed to get graphics Queue family index"); public TransferQueue(Device device, int queueFamilyIndex, int queueIndex) {
} super(device, queueFamilyIndex, queueIndex);
return index;
} }
} }
} }

View file

@ -15,7 +15,6 @@ public class Vulkan {
private static PhysicalDevice physicalDevice; private static PhysicalDevice physicalDevice;
private static Device device; private static Device device;
private static Surface surface; private static Surface surface;
private static Queue.GraphicsQueue graphicsQueue;
public static void init() throws VulkanError { public static void init() throws VulkanError {
if (!glfwVulkanSupported()) { if (!glfwVulkanSupported()) {
@ -26,11 +25,11 @@ public class Vulkan {
surface = new Surface(instance, Display.getWindow()); surface = new Surface(instance, Display.getWindow());
var criteria = new SuitablePhysicalDeviceFinder.Criteria() var criteria = new SuitablePhysicalDeviceFinder.Criteria()
.withGraphicsQueue(true) .withGraphicsQueue(true)
.withComputeQueue(true) .withComputeQueue(true)
.withTransferQueue(true) .withTransferQueue(true)
.withSurfaceSupport(surface) .withSurfaceSupport(surface)
.withExtensions(Set.of(KHRDynamicRendering.VK_KHR_DYNAMIC_RENDERING_EXTENSION_NAME)); .withExtensions(Set.of(KHRDynamicRendering.VK_KHR_DYNAMIC_RENDERING_EXTENSION_NAME));
var physicalDeviceMatch = SuitablePhysicalDeviceFinder.findBestPhysicalDevice(instance, criteria); var physicalDeviceMatch = SuitablePhysicalDeviceFinder.findBestPhysicalDevice(instance, criteria);
if (physicalDeviceMatch == null) { if (physicalDeviceMatch == null) {
@ -38,8 +37,7 @@ public class Vulkan {
} }
physicalDevice = physicalDeviceMatch.physicalDevice; physicalDevice = physicalDeviceMatch.physicalDevice;
device = new Device(physicalDevice); device = new Device(physicalDeviceMatch);
graphicsQueue = new Queue.GraphicsQueue(device, 0);
} }
public static void destroy() { public static void destroy() {

View file

@ -17,17 +17,16 @@ import org.tinylog.Logger;
import fr.mrdev023.vulkan_java.vk.VulkanUtils; import fr.mrdev023.vulkan_java.vk.VulkanUtils;
/** /**
* This class is used to store the instance extensions to use to create the Vulkan instance. * This class is used to store the instance extensions to use to create the
* Vulkan instance.
* *
* @see InstanceExtensions.Selector to select the instance extensions * @see InstanceExtensions.Selector to select the instance extensions
*/ */
public class InstanceExtensions { public class InstanceExtensions {
private Set<String> instanceExtensions; private Set<String> instanceExtensions;
private PointerBuffer glfwRequiredExtensions;
private InstanceExtensions(Set<String> instanceExtensions, PointerBuffer glfwRequiredExtensions) { private InstanceExtensions(Set<String> instanceExtensions) {
this.instanceExtensions = instanceExtensions; this.instanceExtensions = instanceExtensions;
this.glfwRequiredExtensions = glfwRequiredExtensions;
} }
public boolean hasInstanceExtensions() { public boolean hasInstanceExtensions() {
@ -39,10 +38,9 @@ public class InstanceExtensions {
} }
public PointerBuffer writeToStack(MemoryStack stack) { public PointerBuffer writeToStack(MemoryStack stack) {
int numExtensions = instanceExtensions.size() + glfwRequiredExtensions.remaining(); int numExtensions = instanceExtensions.size();
PointerBuffer requiredExtensions = stack.mallocPointer(numExtensions);
requiredExtensions.put(glfwRequiredExtensions); PointerBuffer requiredExtensions = stack.mallocPointer(numExtensions);
for (String extension : instanceExtensions) { for (String extension : instanceExtensions) {
requiredExtensions.put(stack.UTF8(extension)); requiredExtensions.put(stack.UTF8(extension));
} }
@ -62,15 +60,15 @@ public class InstanceExtensions {
* *
* <h5>Specification</h5> * <h5>Specification</h5>
* <ul> * <ul>
* <li>Get the GLFW required extensions to create the Window Surface</li> * <li>Get the GLFW required extensions to create the Window Surface</li>
* <li>Get the portability extensions if the OS needs it</li> * <li>Get the portability extensions if the OS needs it</li>
* <li>Add Debug Utils extension if validation layers are enabled</li> * <li>Add Debug Utils extension if validation layers are enabled</li>
* </ul> * </ul>
* *
* <h5>See Also</h5> * <h5>See Also</h5>
* <ul> * <ul>
* <li>{@link #selectInstanceExtensions()}</li> * <li>{@link #selectInstanceExtensions()}</li>
* <li>{@link #withValidationLayers(InstanceValidationLayers)}</li> * <li>{@link #withValidationLayers(InstanceValidationLayers)}</li>
* </ul> * </ul>
*/ */
public static class Selector { public static class Selector {
@ -90,16 +88,11 @@ public class InstanceExtensions {
Set<String> instanceExtensions = getInstanceExtensions(); Set<String> instanceExtensions = getInstanceExtensions();
log("Supported instance extensions", instanceExtensions); log("Supported instance extensions", instanceExtensions);
// GLFW Extension Set<String> glfwExtensions = getGLFWRequiredExtensions();
PointerBuffer glfwExtensions = GLFWVulkan.glfwGetRequiredInstanceExtensions();
if (glfwExtensions == null) {
throw new RuntimeException("Failed to find the GLFW platform surface extensions");
}
Set<String> portabilityExtensions = getPortabilityExtensions(instanceExtensions); Set<String> portabilityExtensions = getPortabilityExtensions(instanceExtensions);
log("Portability extensions used", portabilityExtensions);
Set<String> selectedExtensions = new HashSet<>(); Set<String> selectedExtensions = new HashSet<>();
selectedExtensions.addAll(glfwExtensions);
selectedExtensions.addAll(portabilityExtensions); selectedExtensions.addAll(portabilityExtensions);
if (validationLayers != null && validationLayers.hasValidationLayers()) { if (validationLayers != null && validationLayers.hasValidationLayers()) {
@ -107,7 +100,7 @@ public class InstanceExtensions {
} }
log("Selected instance extensions", selectedExtensions); log("Selected instance extensions", selectedExtensions);
return new InstanceExtensions(selectedExtensions, glfwExtensions); return new InstanceExtensions(selectedExtensions);
} }
private Set<String> getInstanceExtensions() { private Set<String> getInstanceExtensions() {
@ -118,7 +111,8 @@ public class InstanceExtensions {
VK10.vkEnumerateInstanceExtensionProperties((String) null, numExtensionsBuf, null); VK10.vkEnumerateInstanceExtensionProperties((String) null, numExtensionsBuf, null);
int numExtensions = numExtensionsBuf.get(0); int numExtensions = numExtensionsBuf.get(0);
VkExtensionProperties.Buffer instanceExtensionsProps = VkExtensionProperties.calloc(numExtensions, stack); VkExtensionProperties.Buffer instanceExtensionsProps = VkExtensionProperties.calloc(numExtensions,
stack);
VK10.vkEnumerateInstanceExtensionProperties((String) null, numExtensionsBuf, instanceExtensionsProps); VK10.vkEnumerateInstanceExtensionProperties((String) null, numExtensionsBuf, instanceExtensionsProps);
for (int i = 0; i < numExtensions; i++) { for (int i = 0; i < numExtensions; i++) {
VkExtensionProperties props = instanceExtensionsProps.get(i); VkExtensionProperties props = instanceExtensionsProps.get(i);
@ -135,8 +129,10 @@ public class InstanceExtensions {
var osType = VulkanUtils.getOS(); var osType = VulkanUtils.getOS();
if (osType == VulkanUtils.OSType.MACOS) { if (osType == VulkanUtils.OSType.MACOS) {
if (!instanceExtensions.contains(KHRPortabilityEnumeration.VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME)) { if (!instanceExtensions
throw new RuntimeException("Vulkan instance does not support portability enumeration extension but it's required for MacOS"); .contains(KHRPortabilityEnumeration.VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME)) {
throw new RuntimeException(
"Vulkan instance does not support portability enumeration extension but it's required for MacOS");
} }
portabilityExtensions.add(KHRPortabilityEnumeration.VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME); portabilityExtensions.add(KHRPortabilityEnumeration.VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME);
} }
@ -144,6 +140,19 @@ public class InstanceExtensions {
return portabilityExtensions; return portabilityExtensions;
} }
private Set<String> getGLFWRequiredExtensions() {
PointerBuffer glfwExtensionsBuffer = GLFWVulkan.glfwGetRequiredInstanceExtensions();
if (glfwExtensionsBuffer == null) {
throw new RuntimeException("Failed to find the GLFW platform surface extensions");
}
Set<String> glfwExtensions = new HashSet<>();
for (int i = 0; i < glfwExtensionsBuffer.remaining(); i++) {
glfwExtensions.add(glfwExtensionsBuffer.getStringUTF8(i));
}
return glfwExtensions;
}
private void log(String title, Set<String> layers) { private void log(String title, Set<String> layers) {
Logger.debug("{} ({})", title, layers.size()); Logger.debug("{} ({})", title, layers.size());
for (String layer : layers) { for (String layer : layers) {

View file

@ -35,37 +35,47 @@ public class SuitablePhysicalDeviceFinder {
return matchedPhysicalDevices.stream().min(Comparator.comparingInt(MatchResult::getScore)).orElse(null); return matchedPhysicalDevices.stream().min(Comparator.comparingInt(MatchResult::getScore)).orElse(null);
} }
private static MatchResult checkPhysicalDevice(MemoryStack stack, PhysicalDevice physicalDevice, Criteria criteria) { private static MatchResult checkPhysicalDevice(MemoryStack stack, PhysicalDevice physicalDevice,
Criteria criteria) {
int graphicsQueueFamilyIndex = -1; int graphicsQueueFamilyIndex = -1;
int computeQueueFamilyIndex = -1; int computeQueueFamilyIndex = -1;
int transferQueueFamilyIndex = -1; int transferQueueFamilyIndex = -1;
boolean surfaceSupport = false; boolean surfaceSupport = false;
IntBuffer presentSupport = stack.ints(VK10.VK_FALSE); IntBuffer presentSupport = stack.ints(VK10.VK_FALSE);
var vkQueueFamilyProps = physicalDevice.getVkQueueFamilyProps(); var vkQueueFamilyProps = physicalDevice.getVkQueueFamilyProps();
for (int i = 0; i < vkQueueFamilyProps.capacity(); i++) { for (int i = 0; i < vkQueueFamilyProps.capacity(); i++) {
var vkQueueFamilyProp = vkQueueFamilyProps.get(i); var vkQueueFamilyProp = vkQueueFamilyProps.get(i);
if ((vkQueueFamilyProp.queueFlags() & VK10.VK_QUEUE_GRAPHICS_BIT) != 0) { if (graphicsQueueFamilyIndex == -1 && (vkQueueFamilyProp.queueFlags() & VK10.VK_QUEUE_GRAPHICS_BIT) != 0) {
graphicsQueueFamilyIndex = i; graphicsQueueFamilyIndex = i;
} }
if ((vkQueueFamilyProp.queueFlags() & VK10.VK_QUEUE_COMPUTE_BIT) != 0) { if (computeQueueFamilyIndex == -1 && (vkQueueFamilyProp.queueFlags() & VK10.VK_QUEUE_COMPUTE_BIT) != 0) {
computeQueueFamilyIndex = i; computeQueueFamilyIndex = i;
} }
if ((vkQueueFamilyProp.queueFlags() & VK10.VK_QUEUE_TRANSFER_BIT) != 0) { if (transferQueueFamilyIndex == -1 && (vkQueueFamilyProp.queueFlags() & VK10.VK_QUEUE_TRANSFER_BIT) != 0) {
transferQueueFamilyIndex = i; transferQueueFamilyIndex = i;
} }
KHRSurface.vkGetPhysicalDeviceSurfaceSupportKHR(physicalDevice.getVkPhysicalDevice(), i, criteria.withSurfaceSupport.getVkSurface(), presentSupport); KHRSurface.vkGetPhysicalDeviceSurfaceSupportKHR(physicalDevice.getVkPhysicalDevice(), i,
criteria.withSurfaceSupport.getVkSurface(), presentSupport);
if (presentSupport.get(0) == VK10.VK_TRUE) { if (presentSupport.get(0) == VK10.VK_TRUE) {
surfaceSupport = true; surfaceSupport = true;
} }
if ((!criteria.withGraphicsQueue || graphicsQueueFamilyIndex != -1) &&
(!criteria.withComputeQueue || computeQueueFamilyIndex != -1) &&
(!criteria.withTransferQueue || transferQueueFamilyIndex != -1) &&
(criteria.withSurfaceSupport == null || surfaceSupport)) {
// We found a suitable queue family, we can break the loop
break;
}
} }
return new MatchResult(physicalDevice, graphicsQueueFamilyIndex, computeQueueFamilyIndex, transferQueueFamilyIndex, surfaceSupport); return new MatchResult(physicalDevice, graphicsQueueFamilyIndex, computeQueueFamilyIndex,
transferQueueFamilyIndex, surfaceSupport);
} }
public static class MatchResult { public static class MatchResult {
@ -75,7 +85,8 @@ public class SuitablePhysicalDeviceFinder {
public final int computeQueueFamilyIndex; public final int computeQueueFamilyIndex;
public final int transferQueueFamilyIndex; public final int transferQueueFamilyIndex;
public MatchResult(PhysicalDevice physicalDevice, int graphicsQueueFamilyIndex, int computeQueueFamilyIndex, int transferQueueFamilyIndex, boolean surfaceSupport) { public MatchResult(PhysicalDevice physicalDevice, int graphicsQueueFamilyIndex, int computeQueueFamilyIndex,
int transferQueueFamilyIndex, boolean surfaceSupport) {
this.physicalDevice = physicalDevice; this.physicalDevice = physicalDevice;
this.graphicsQueueFamilyIndex = graphicsQueueFamilyIndex; this.graphicsQueueFamilyIndex = graphicsQueueFamilyIndex;
this.computeQueueFamilyIndex = computeQueueFamilyIndex; this.computeQueueFamilyIndex = computeQueueFamilyIndex;
@ -89,9 +100,9 @@ public class SuitablePhysicalDeviceFinder {
Logger.debug("\t\tName: {}", physicalDevice.getVkPhysicalDeviceProperties().deviceNameString()); Logger.debug("\t\tName: {}", physicalDevice.getVkPhysicalDeviceProperties().deviceNameString());
int apiVersion = physicalDevice.getVkPhysicalDeviceProperties().apiVersion(); int apiVersion = physicalDevice.getVkPhysicalDeviceProperties().apiVersion();
Logger.debug("\t\tAPI version: {}.{}.{}", Logger.debug("\t\tAPI version: {}.{}.{}",
VK10.VK_API_VERSION_MAJOR(apiVersion), VK10.VK_API_VERSION_MAJOR(apiVersion),
VK10.VK_API_VERSION_MINOR(apiVersion), VK10.VK_API_VERSION_MINOR(apiVersion),
VK10.VK_API_VERSION_PATCH(apiVersion)); VK10.VK_API_VERSION_PATCH(apiVersion));
Logger.debug("\t\tDevice type: {}", physicalDevice.getVkPhysicalDeviceProperties().deviceType()); Logger.debug("\t\tDevice type: {}", physicalDevice.getVkPhysicalDeviceProperties().deviceType());
Logger.debug("\tRequired supports checking report"); Logger.debug("\tRequired supports checking report");
@ -101,7 +112,7 @@ public class SuitablePhysicalDeviceFinder {
Logger.debug("\t\t[FAILED] Graphics queue is not supported"); Logger.debug("\t\t[FAILED] Graphics queue is not supported");
return false; return false;
} else { } else {
Logger.debug("\t\t[OK] Graphics queue is supported"); Logger.debug("\t\t[OK] Graphics queue is supported (family index: {})", graphicsQueueFamilyIndex);
} }
if (!criteria.withComputeQueue) { if (!criteria.withComputeQueue) {
@ -110,7 +121,7 @@ public class SuitablePhysicalDeviceFinder {
Logger.debug("\t\t[FAILED] Compute queue is not supported"); Logger.debug("\t\t[FAILED] Compute queue is not supported");
return false; return false;
} else { } else {
Logger.debug("\t\t[OK] Compute queue is supported"); Logger.debug("\t\t[OK] Compute queue is supported (family index: {})", computeQueueFamilyIndex);
} }
if (!criteria.withTransferQueue) { if (!criteria.withTransferQueue) {
@ -119,7 +130,7 @@ public class SuitablePhysicalDeviceFinder {
Logger.debug("\t\t[FAILED] Transfer queue is not supported"); Logger.debug("\t\t[FAILED] Transfer queue is not supported");
return false; return false;
} else { } else {
Logger.debug("\t\t[OK] Transfer queue is supported"); Logger.debug("\t\t[OK] Transfer queue is supported (family index: {})", transferQueueFamilyIndex);
} }
if (criteria.withSurfaceSupport == null) { if (criteria.withSurfaceSupport == null) {
@ -133,8 +144,10 @@ public class SuitablePhysicalDeviceFinder {
if (criteria.extensions == null) { if (criteria.extensions == null) {
Logger.debug("\t\t[SKIPPED] Required extensions is empty"); Logger.debug("\t\t[SKIPPED] Required extensions is empty");
} else if (physicalDevice.getVkDeviceExtensions().stream().allMatch(extension -> criteria.extensions.contains(extension.extensionNameString()))) { } else if (physicalDevice.getVkDeviceExtensions().stream()
Logger.debug("\t\t[FAILED] Required extensions are not supported [{}]", String.join(", ", criteria.extensions)); .allMatch(extension -> criteria.extensions.contains(extension.extensionNameString()))) {
Logger.debug("\t\t[FAILED] Required extensions are not supported [{}]",
String.join(", ", criteria.extensions));
return false; return false;
} else { } else {
Logger.debug("\t\t[OK] Required extensions are supported [{}]", String.join(", ", criteria.extensions)); Logger.debug("\t\t[OK] Required extensions are supported [{}]", String.join(", ", criteria.extensions));