From aeaed68efb10e519e2c79d35930d82b46a0ccc0a Mon Sep 17 00:00:00 2001 From: kearwood Date: Fri, 9 Sep 2022 00:36:22 -0700 Subject: [PATCH] Moved ShaderStage to KRShader and expanded to include all stages. Added mapping functions, getShaderStageFromExtension and getShaderStageFlagBitsFromShaderStage. KRShader::m_stage is now typed as ShaderStage. --- kraken/KRContext.cpp | 2 +- kraken/KRPipeline.cpp | 27 +++---------- kraken/KRPipeline.h | 16 ++------ kraken/KRShader.cpp | 78 ++++++++++++++++++++++++++++---------- kraken/KRShader.h | 31 +++++++++++++-- kraken/KRSourceManager.cpp | 4 +- 6 files changed, 99 insertions(+), 59 deletions(-) diff --git a/kraken/KRContext.cpp b/kraken/KRContext.cpp index 9486f8b..d3c93e2 100755 --- a/kraken/KRContext.cpp +++ b/kraken/KRContext.cpp @@ -337,7 +337,7 @@ KRResource* KRContext::loadResource(const std::string& file_name, KRDataBlock* d } else if (extension.compare("spv") == 0) { // SPIR-V shader binary resource = m_pShaderManager->load(name, extension, data); - } else if (getShaderStageFromExtension(extension.c_str()) != 0) { + } else if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid) { // Shader source resource = m_pSourceManager->load(name, extension, data); } else if (extension.compare("glsl") == 0) { diff --git a/kraken/KRPipeline.cpp b/kraken/KRPipeline.cpp index 47db428..71cc735 100644 --- a/kraken/KRPipeline.cpp +++ b/kraken/KRPipeline.cpp @@ -178,12 +178,12 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf binding.descriptorType = static_cast(binding_reflect.descriptor_type); binding.descriptorCount = binding_reflect.count; binding.pImmutableSamplers = nullptr; - binding.stageFlags = shader->getShaderStage(); + binding.stageFlags = shader->getShaderStageFlagBits(); } VkPipelineShaderStageCreateInfo& stageInfo = stages[stage_count++]; stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - stageInfo.stage = shader->getShaderStage(); + stageInfo.stage = shader->getShaderStageFlagBits(); if (stageInfo.stage == VK_SHADER_STAGE_VERTEX_BIT) { for (uint32_t i = 0; i < reflection->input_variable_count; i++) { @@ -206,10 +206,10 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf } } - initPushConstantStage(ShaderStages::vertex, reflection); + initPushConstantStage(ShaderStage::vert, reflection); } else if (stageInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT) { - initPushConstantStage(ShaderStages::fragment, reflection); + initPushConstantStage(ShaderStage::frag, reflection); } else { // failed! TODO - Error handling } @@ -395,22 +395,7 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf VkPushConstantRange push_constant{}; push_constant.offset = 0; push_constant.size = pushConstants.bufferSize; - - switch (static_cast(iStage)) { - case ShaderStages::vertex: - push_constant.stageFlags = VK_SHADER_STAGE_VERTEX_BIT; - break; - case ShaderStages::fragment: - push_constant.stageFlags = VK_SHADER_STAGE_FRAGMENT_BIT; - break; - case ShaderStages::geometry: - push_constant.stageFlags = VK_SHADER_STAGE_GEOMETRY_BIT; - break; - case ShaderStages::compute: - push_constant.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - break; - } - + push_constant.stageFlags = getShaderStageFlagBitsFromShaderStage(static_cast(iStage)); pushConstantsLayoutInfo.pPushConstantRanges = &push_constant; pushConstantsLayoutInfo.pushConstantRangeCount = 1; @@ -511,7 +496,7 @@ KRPipeline::~KRPipeline() } } -void KRPipeline::initPushConstantStage(ShaderStages stage, const SpvReflectShaderModule* reflection) +void KRPipeline::initPushConstantStage(ShaderStage stage, const SpvReflectShaderModule* reflection) { PushConstantStageInfo& pushConstants = m_pushConstants[static_cast(stage)]; for (int i = 0; i < reflection->push_constant_block_count; i++) { diff --git a/kraken/KRPipeline.h b/kraken/KRPipeline.h index 50b8d26..64ddb16 100644 --- a/kraken/KRPipeline.h +++ b/kraken/KRPipeline.h @@ -37,6 +37,7 @@ #include "KRNode.h" #include "KRViewport.h" #include "KRMesh.h" +#include "KRShader.h" class KRShader; class KRSurface; @@ -292,17 +293,6 @@ public: static const size_t kPushConstantCount = static_cast(PushConstant::NUM_PUSH_CONSTANTS); - enum class ShaderStages : uint8_t - { - vertex = 0, - fragment, - geometry, - compute, - shaderStageCount - }; - - static const size_t kShaderStageCount = static_cast(ShaderStages::shaderStageCount); - bool hasPushConstant(PushConstant location) const; void setPushConstant(PushConstant location, float value); void setPushConstant(PushConstant location, int value); @@ -325,7 +315,7 @@ private: int bufferSize; VkPipelineLayout layout; }; - PushConstantStageInfo m_pushConstants[static_cast(ShaderStages::shaderStageCount)]; + PushConstantStageInfo m_pushConstants[static_cast(ShaderStage::ShaderStageCount)]; char m_szKey[256]; @@ -333,5 +323,5 @@ private: VkPipelineLayout m_pipelineLayout; VkPipeline m_graphicsPipeline; - void initPushConstantStage(ShaderStages stage, const SpvReflectShaderModule* reflection); + void initPushConstantStage(ShaderStage stage, const SpvReflectShaderModule* reflection); }; diff --git a/kraken/KRShader.cpp b/kraken/KRShader.cpp index abe7003..4efb1e3 100644 --- a/kraken/KRShader.cpp +++ b/kraken/KRShader.cpp @@ -32,41 +32,76 @@ #include "KRShader.h" #include "spirv_reflect.h" -VkShaderStageFlagBits getShaderStageFromExtension(const char* extension) +ShaderStage getShaderStageFromExtension(const char* extension) { if (strcmp(extension, "vert") == 0) { - return VK_SHADER_STAGE_VERTEX_BIT; + return ShaderStage::vert; } else if (strcmp(extension, "frag") == 0) { - return VK_SHADER_STAGE_FRAGMENT_BIT; + return ShaderStage::frag; } else if (strcmp(extension, "tesc") == 0) { - return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; + return ShaderStage::tesc; } else if (strcmp(extension, "tese") == 0) { - return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + return ShaderStage::tese; } else if (strcmp(extension, "geom") == 0) { - return VK_SHADER_STAGE_GEOMETRY_BIT; + return ShaderStage::geom; } else if (strcmp(extension, "comp") == 0) { - return VK_SHADER_STAGE_COMPUTE_BIT; + return ShaderStage::comp; } else if (strcmp(extension, "mesh") == 0) { - return VK_SHADER_STAGE_MESH_BIT_NV; + return ShaderStage::mesh; } else if (strcmp(extension, "task") == 0) { - return VK_SHADER_STAGE_TASK_BIT_NV; + return ShaderStage::task; } else if (strcmp(extension, "rgen") == 0) { - return VK_SHADER_STAGE_RAYGEN_BIT_KHR; + return ShaderStage::rgen; } else if (strcmp(extension, "rint") == 0) { - return VK_SHADER_STAGE_INTERSECTION_BIT_KHR; + return ShaderStage::rint; } else if (strcmp(extension, "rahit") == 0) { - return VK_SHADER_STAGE_ANY_HIT_BIT_KHR; + return ShaderStage::rahit; } else if (strcmp(extension, "rchit") == 0) { - return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR; + return ShaderStage::rchit; } else if (strcmp(extension, "rmiss") == 0) { - return VK_SHADER_STAGE_MISS_BIT_KHR; - } else if (strcmp(extension, "rmiss") == 0) { - return VK_SHADER_STAGE_CALLABLE_BIT_KHR; + return ShaderStage::rmiss; + } else if (strcmp(extension, "rcall") == 0) { + return ShaderStage::rcall; } else { - return (VkShaderStageFlagBits)0; + return ShaderStage::Invalid; } } +VkShaderStageFlagBits getShaderStageFlagBitsFromShaderStage(ShaderStage stage) +{ + switch (stage) { + case ShaderStage::vert: + return VK_SHADER_STAGE_VERTEX_BIT; + case ShaderStage::frag: + return VK_SHADER_STAGE_FRAGMENT_BIT; + case ShaderStage::tesc: + return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; + case ShaderStage::tese: + return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + case ShaderStage::geom: + return VK_SHADER_STAGE_GEOMETRY_BIT; + case ShaderStage::comp: + return VK_SHADER_STAGE_COMPUTE_BIT; + case ShaderStage::mesh: + return VK_SHADER_STAGE_MESH_BIT_NV; + case ShaderStage::task: + return VK_SHADER_STAGE_TASK_BIT_NV; + case ShaderStage::rgen: + return VK_SHADER_STAGE_RAYGEN_BIT_KHR; + case ShaderStage::rint: + return VK_SHADER_STAGE_INTERSECTION_BIT_KHR; + case ShaderStage::rahit: + return VK_SHADER_STAGE_ANY_HIT_BIT_KHR; + case ShaderStage::rchit: + return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR; + case ShaderStage::rmiss: + return VK_SHADER_STAGE_MISS_BIT_KHR; + case ShaderStage::rcall: + return VK_SHADER_STAGE_CALLABLE_BIT_KHR; + } + return (VkShaderStageFlagBits)0; +} + KRShader::KRShader(KRContext& context, std::string name, std::string extension) : KRResource(context, name) { m_pData = new KRDataBlock(); @@ -181,7 +216,12 @@ const SpvReflectShaderModule* KRShader::getReflection() return nullptr; } -VkShaderStageFlagBits KRShader::getShaderStage() const +ShaderStage KRShader::getShaderStage() const { return m_stage; -} \ No newline at end of file +} + +VkShaderStageFlagBits KRShader::getShaderStageFlagBits() const +{ + return getShaderStageFlagBitsFromShaderStage(m_stage); +} diff --git a/kraken/KRShader.h b/kraken/KRShader.h index 4b0216b..94c64c5 100644 --- a/kraken/KRShader.h +++ b/kraken/KRShader.h @@ -37,7 +37,30 @@ #include "KRResource.h" #include "spirv_reflect.h" -VkShaderStageFlagBits getShaderStageFromExtension(const char* extension); +enum class ShaderStage : uint8_t +{ + vert = 0, + frag, + tesc, + tese, + geom, + comp, + mesh, + task, + rgen, + rint, + rahit, + rchit, + rmiss, + rcall, + ShaderStageCount, + Invalid = 0xff +}; + +ShaderStage getShaderStageFromExtension(const char* extension); +VkShaderStageFlagBits getShaderStageFlagBitsFromShaderStage(ShaderStage stage); + +static const size_t kShaderStageCount = static_cast(ShaderStage::ShaderStageCount); class KRShader : public KRResource { @@ -55,7 +78,9 @@ public: KRDataBlock* getData(); const SpvReflectShaderModule* getReflection(); - VkShaderStageFlagBits getShaderStage() const; + ShaderStage getShaderStage() const; + VkShaderStageFlagBits getShaderStageFlagBits() const; + private: @@ -68,5 +93,5 @@ private: void parseReflection(); void freeReflection(); - VkShaderStageFlagBits m_stage; + ShaderStage m_stage; }; diff --git a/kraken/KRSourceManager.cpp b/kraken/KRSourceManager.cpp index 0f099be..4db226b 100644 --- a/kraken/KRSourceManager.cpp +++ b/kraken/KRSourceManager.cpp @@ -77,7 +77,7 @@ void KRSourceManager::add(KRSource* source) KRResource* KRSourceManager::loadResource(const std::string& name, const std::string& extension, KRDataBlock* data) { - if (getShaderStageFromExtension(extension.c_str()) != 0 || + if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid || extension.compare("glsl") == 0 || extension.compare("options") == 0) { return load(name, extension, data); @@ -87,7 +87,7 @@ KRResource* KRSourceManager::loadResource(const std::string& name, const std::st KRResource* KRSourceManager::getResource(const std::string& name, const std::string& extension) { - if (getShaderStageFromExtension(extension.c_str()) != 0 || + if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid || extension.compare("glsl") == 0 || extension.compare("options") == 0) { return get(name, extension);