From f20f7f73d6f2c1b9f727e4a5ee7a796bfb2b223c Mon Sep 17 00:00:00 2001 From: kearwood Date: Thu, 8 Sep 2022 23:58:24 -0700 Subject: [PATCH] Centralized shader file sub-extension to shader stage mapping to KRShader. --- kraken/KRContext.cpp | 43 ++------------------------------------ kraken/KRPipeline.cpp | 16 ++++---------- kraken/KRShader.cpp | 42 +++++++++++++++++++++++++++++++++++++ kraken/KRShader.h | 5 +++++ kraken/KRSourceManager.cpp | 31 +++------------------------ 5 files changed, 56 insertions(+), 81 deletions(-) diff --git a/kraken/KRContext.cpp b/kraken/KRContext.cpp index 4c6366d..9486f8b 100755 --- a/kraken/KRContext.cpp +++ b/kraken/KRContext.cpp @@ -337,47 +337,8 @@ KRResource* KRContext::loadResource(const std::string& file_name, KRDataBlock* d } else if (extension.compare("spv") == 0) { // SPIR-V shader binary resource = m_pShaderManager->load(name, extension, data); - } else if (extension.compare("vert") == 0) { - // vertex shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("frag") == 0) { - // fragment shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("tesc") == 0) { - // tessellation control shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("tese") == 0) { - // tessellation evaluation shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("geom") == 0) { - // geometry shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("comp") == 0) { - // compute shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("mesh") == 0) { - // mesh shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("task") == 0) { - // task shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("rgen") == 0) { - // ray generation shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("rint") == 0) { - // ray intersection shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("rahit") == 0) { - // ray any hit shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("rchit") == 0) { - // ray closest hit shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("rmiss") == 0) { - // ray miss shader - resource = m_pSourceManager->load(name, extension, data); - } else if (extension.compare("rcall") == 0) { - // ray callable shader + } else if (getShaderStageFromExtension(extension.c_str()) != 0) { + // Shader source resource = m_pSourceManager->load(name, extension, data); } else if (extension.compare("glsl") == 0) { // glsl included by other shaders diff --git a/kraken/KRPipeline.cpp b/kraken/KRPipeline.cpp index af2a56c..47db428 100644 --- a/kraken/KRPipeline.cpp +++ b/kraken/KRPipeline.cpp @@ -178,20 +178,13 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf binding.descriptorType = static_cast(binding_reflect.descriptor_type); binding.descriptorCount = binding_reflect.count; binding.pImmutableSamplers = nullptr; - if (shader->getSubExtension().compare("vert") == 0) { - binding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT; - } else if (shader->getSubExtension().compare("frag") == 0) { - binding.stageFlags = VK_SHADER_STAGE_FRAGMENT_BIT; - } else { - // TODO - Error handling, support more stages - // Should probably make a lookup table for mapping extensions to stages - } + binding.stageFlags = shader->getShaderStage(); } VkPipelineShaderStageCreateInfo& stageInfo = stages[stage_count++]; stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - if (shader->getSubExtension().compare("vert") == 0) { - stageInfo.stage = VK_SHADER_STAGE_VERTEX_BIT; + stageInfo.stage = shader->getShaderStage(); + if (stageInfo.stage == VK_SHADER_STAGE_VERTEX_BIT) { for (uint32_t i = 0; i < reflection->input_variable_count; i++) { // TODO - We should have an interface to allow classes such as KRMesh to expose bindings @@ -215,8 +208,7 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf initPushConstantStage(ShaderStages::vertex, reflection); - } else if (shader->getSubExtension().compare("frag") == 0) { - stageInfo.stage = VK_SHADER_STAGE_FRAGMENT_BIT; + } else if (stageInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT) { initPushConstantStage(ShaderStages::fragment, reflection); } else { // failed! TODO - Error handling diff --git a/kraken/KRShader.cpp b/kraken/KRShader.cpp index 8511f2f..abe7003 100644 --- a/kraken/KRShader.cpp +++ b/kraken/KRShader.cpp @@ -32,11 +32,47 @@ #include "KRShader.h" #include "spirv_reflect.h" +VkShaderStageFlagBits getShaderStageFromExtension(const char* extension) +{ + if (strcmp(extension, "vert") == 0) { + return VK_SHADER_STAGE_VERTEX_BIT; + } else if (strcmp(extension, "frag") == 0) { + return VK_SHADER_STAGE_FRAGMENT_BIT; + } else if (strcmp(extension, "tesc") == 0) { + return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT; + } else if (strcmp(extension, "tese") == 0) { + return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; + } else if (strcmp(extension, "geom") == 0) { + return VK_SHADER_STAGE_GEOMETRY_BIT; + } else if (strcmp(extension, "comp") == 0) { + return VK_SHADER_STAGE_COMPUTE_BIT; + } else if (strcmp(extension, "mesh") == 0) { + return VK_SHADER_STAGE_MESH_BIT_NV; + } else if (strcmp(extension, "task") == 0) { + return VK_SHADER_STAGE_TASK_BIT_NV; + } else if (strcmp(extension, "rgen") == 0) { + return VK_SHADER_STAGE_RAYGEN_BIT_KHR; + } else if (strcmp(extension, "rint") == 0) { + return VK_SHADER_STAGE_INTERSECTION_BIT_KHR; + } else if (strcmp(extension, "rahit") == 0) { + return VK_SHADER_STAGE_ANY_HIT_BIT_KHR; + } else if (strcmp(extension, "rchit") == 0) { + return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR; + } else if (strcmp(extension, "rmiss") == 0) { + return VK_SHADER_STAGE_MISS_BIT_KHR; + } else if (strcmp(extension, "rmiss") == 0) { + return VK_SHADER_STAGE_CALLABLE_BIT_KHR; + } else { + return (VkShaderStageFlagBits)0; + } +} + KRShader::KRShader(KRContext& context, std::string name, std::string extension) : KRResource(context, name) { m_pData = new KRDataBlock(); m_extension = extension; m_subExtension = KRResource::GetFileExtension(name); + m_stage = getShaderStageFromExtension(m_subExtension.c_str()); m_reflectionValid = false; getReflection(); @@ -47,6 +83,7 @@ KRShader::KRShader(KRContext& context, std::string name, std::string extension, m_pData = data; m_extension = extension; m_subExtension = KRResource::GetFileExtension(name); + m_stage = getShaderStageFromExtension(m_subExtension.c_str()); m_reflectionValid = false; } @@ -143,3 +180,8 @@ const SpvReflectShaderModule* KRShader::getReflection() } return nullptr; } + +VkShaderStageFlagBits KRShader::getShaderStage() const +{ + return m_stage; +} \ No newline at end of file diff --git a/kraken/KRShader.h b/kraken/KRShader.h index bdbe2ec..4b0216b 100644 --- a/kraken/KRShader.h +++ b/kraken/KRShader.h @@ -37,6 +37,8 @@ #include "KRResource.h" #include "spirv_reflect.h" +VkShaderStageFlagBits getShaderStageFromExtension(const char* extension); + class KRShader : public KRResource { public: @@ -53,6 +55,7 @@ public: KRDataBlock* getData(); const SpvReflectShaderModule* getReflection(); + VkShaderStageFlagBits getShaderStage() const; private: @@ -64,4 +67,6 @@ private: void parseReflection(); void freeReflection(); + + VkShaderStageFlagBits m_stage; }; diff --git a/kraken/KRSourceManager.cpp b/kraken/KRSourceManager.cpp index 9d38ba2..0f099be 100644 --- a/kraken/KRSourceManager.cpp +++ b/kraken/KRSourceManager.cpp @@ -31,6 +31,7 @@ #include "KRSourceManager.h" #include "KREngine-common.h" +#include "KRShader.h" KRSourceManager::KRSourceManager(KRContext& context) : KRResourceManager(context) { @@ -76,20 +77,7 @@ void KRSourceManager::add(KRSource* source) KRResource* KRSourceManager::loadResource(const std::string& name, const std::string& extension, KRDataBlock* data) { - if (extension.compare("vert") == 0 || - extension.compare("frag") == 0 || - extension.compare("tesc") == 0 || - extension.compare("tese") == 0 || - extension.compare("geom") == 0 || - extension.compare("comp") == 0 || - extension.compare("mesh") == 0 || - extension.compare("task") == 0 || - extension.compare("rgen") == 0 || - extension.compare("rint") == 0 || - extension.compare("rahit") == 0 || - extension.compare("rchit") == 0 || - extension.compare("rmiss") == 0 || - extension.compare("rcall") == 0 || + if (getShaderStageFromExtension(extension.c_str()) != 0 || extension.compare("glsl") == 0 || extension.compare("options") == 0) { return load(name, extension, data); @@ -99,20 +87,7 @@ KRResource* KRSourceManager::loadResource(const std::string& name, const std::st KRResource* KRSourceManager::getResource(const std::string& name, const std::string& extension) { - if (extension.compare("vert") == 0 || - extension.compare("frag") == 0 || - extension.compare("tesc") == 0 || - extension.compare("tese") == 0 || - extension.compare("geom") == 0 || - extension.compare("comp") == 0 || - extension.compare("mesh") == 0 || - extension.compare("task") == 0 || - extension.compare("rgen") == 0 || - extension.compare("rint") == 0 || - extension.compare("rahit") == 0 || - extension.compare("rchit") == 0 || - extension.compare("rmiss") == 0 || - extension.compare("rcall") == 0 || + if (getShaderStageFromExtension(extension.c_str()) != 0 || extension.compare("glsl") == 0 || extension.compare("options") == 0) { return get(name, extension);