Centralized shader file sub-extension to shader stage mapping to KRShader.

This commit is contained in:
2022-09-08 23:58:24 -07:00
parent e695bca3f9
commit f20f7f73d6
5 changed files with 56 additions and 81 deletions

View File

@@ -337,47 +337,8 @@ KRResource* KRContext::loadResource(const std::string& file_name, KRDataBlock* d
} else if (extension.compare("spv") == 0) { } else if (extension.compare("spv") == 0) {
// SPIR-V shader binary // SPIR-V shader binary
resource = m_pShaderManager->load(name, extension, data); resource = m_pShaderManager->load(name, extension, data);
} else if (extension.compare("vert") == 0) { } else if (getShaderStageFromExtension(extension.c_str()) != 0) {
// vertex shader // Shader source
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("frag") == 0) {
// fragment shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("tesc") == 0) {
// tessellation control shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("tese") == 0) {
// tessellation evaluation shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("geom") == 0) {
// geometry shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("comp") == 0) {
// compute shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("mesh") == 0) {
// mesh shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("task") == 0) {
// task shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("rgen") == 0) {
// ray generation shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("rint") == 0) {
// ray intersection shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("rahit") == 0) {
// ray any hit shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("rchit") == 0) {
// ray closest hit shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("rmiss") == 0) {
// ray miss shader
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("rcall") == 0) {
// ray callable shader
resource = m_pSourceManager->load(name, extension, data); resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("glsl") == 0) { } else if (extension.compare("glsl") == 0) {
// glsl included by other shaders // glsl included by other shaders

View File

@@ -178,20 +178,13 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
binding.descriptorType = static_cast<VkDescriptorType>(binding_reflect.descriptor_type); binding.descriptorType = static_cast<VkDescriptorType>(binding_reflect.descriptor_type);
binding.descriptorCount = binding_reflect.count; binding.descriptorCount = binding_reflect.count;
binding.pImmutableSamplers = nullptr; binding.pImmutableSamplers = nullptr;
if (shader->getSubExtension().compare("vert") == 0) { binding.stageFlags = shader->getShaderStage();
binding.stageFlags = VK_SHADER_STAGE_VERTEX_BIT;
} else if (shader->getSubExtension().compare("frag") == 0) {
binding.stageFlags = VK_SHADER_STAGE_FRAGMENT_BIT;
} else {
// TODO - Error handling, support more stages
// Should probably make a lookup table for mapping extensions to stages
}
} }
VkPipelineShaderStageCreateInfo& stageInfo = stages[stage_count++]; VkPipelineShaderStageCreateInfo& stageInfo = stages[stage_count++];
stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
if (shader->getSubExtension().compare("vert") == 0) { stageInfo.stage = shader->getShaderStage();
stageInfo.stage = VK_SHADER_STAGE_VERTEX_BIT; if (stageInfo.stage == VK_SHADER_STAGE_VERTEX_BIT) {
for (uint32_t i = 0; i < reflection->input_variable_count; i++) { for (uint32_t i = 0; i < reflection->input_variable_count; i++) {
// TODO - We should have an interface to allow classes such as KRMesh to expose bindings // TODO - We should have an interface to allow classes such as KRMesh to expose bindings
@@ -215,8 +208,7 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
initPushConstantStage(ShaderStages::vertex, reflection); initPushConstantStage(ShaderStages::vertex, reflection);
} else if (shader->getSubExtension().compare("frag") == 0) { } else if (stageInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT) {
stageInfo.stage = VK_SHADER_STAGE_FRAGMENT_BIT;
initPushConstantStage(ShaderStages::fragment, reflection); initPushConstantStage(ShaderStages::fragment, reflection);
} else { } else {
// failed! TODO - Error handling // failed! TODO - Error handling

View File

@@ -32,11 +32,47 @@
#include "KRShader.h" #include "KRShader.h"
#include "spirv_reflect.h" #include "spirv_reflect.h"
VkShaderStageFlagBits getShaderStageFromExtension(const char* extension)
{
if (strcmp(extension, "vert") == 0) {
return VK_SHADER_STAGE_VERTEX_BIT;
} else if (strcmp(extension, "frag") == 0) {
return VK_SHADER_STAGE_FRAGMENT_BIT;
} else if (strcmp(extension, "tesc") == 0) {
return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
} else if (strcmp(extension, "tese") == 0) {
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
} else if (strcmp(extension, "geom") == 0) {
return VK_SHADER_STAGE_GEOMETRY_BIT;
} else if (strcmp(extension, "comp") == 0) {
return VK_SHADER_STAGE_COMPUTE_BIT;
} else if (strcmp(extension, "mesh") == 0) {
return VK_SHADER_STAGE_MESH_BIT_NV;
} else if (strcmp(extension, "task") == 0) {
return VK_SHADER_STAGE_TASK_BIT_NV;
} else if (strcmp(extension, "rgen") == 0) {
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
} else if (strcmp(extension, "rint") == 0) {
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
} else if (strcmp(extension, "rahit") == 0) {
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
} else if (strcmp(extension, "rchit") == 0) {
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
} else if (strcmp(extension, "rmiss") == 0) {
return VK_SHADER_STAGE_MISS_BIT_KHR;
} else if (strcmp(extension, "rmiss") == 0) {
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
} else {
return (VkShaderStageFlagBits)0;
}
}
KRShader::KRShader(KRContext& context, std::string name, std::string extension) : KRResource(context, name) KRShader::KRShader(KRContext& context, std::string name, std::string extension) : KRResource(context, name)
{ {
m_pData = new KRDataBlock(); m_pData = new KRDataBlock();
m_extension = extension; m_extension = extension;
m_subExtension = KRResource::GetFileExtension(name); m_subExtension = KRResource::GetFileExtension(name);
m_stage = getShaderStageFromExtension(m_subExtension.c_str());
m_reflectionValid = false; m_reflectionValid = false;
getReflection(); getReflection();
@@ -47,6 +83,7 @@ KRShader::KRShader(KRContext& context, std::string name, std::string extension,
m_pData = data; m_pData = data;
m_extension = extension; m_extension = extension;
m_subExtension = KRResource::GetFileExtension(name); m_subExtension = KRResource::GetFileExtension(name);
m_stage = getShaderStageFromExtension(m_subExtension.c_str());
m_reflectionValid = false; m_reflectionValid = false;
} }
@@ -143,3 +180,8 @@ const SpvReflectShaderModule* KRShader::getReflection()
} }
return nullptr; return nullptr;
} }
VkShaderStageFlagBits KRShader::getShaderStage() const
{
return m_stage;
}

View File

@@ -37,6 +37,8 @@
#include "KRResource.h" #include "KRResource.h"
#include "spirv_reflect.h" #include "spirv_reflect.h"
VkShaderStageFlagBits getShaderStageFromExtension(const char* extension);
class KRShader : public KRResource class KRShader : public KRResource
{ {
public: public:
@@ -53,6 +55,7 @@ public:
KRDataBlock* getData(); KRDataBlock* getData();
const SpvReflectShaderModule* getReflection(); const SpvReflectShaderModule* getReflection();
VkShaderStageFlagBits getShaderStage() const;
private: private:
@@ -64,4 +67,6 @@ private:
void parseReflection(); void parseReflection();
void freeReflection(); void freeReflection();
VkShaderStageFlagBits m_stage;
}; };

View File

@@ -31,6 +31,7 @@
#include "KRSourceManager.h" #include "KRSourceManager.h"
#include "KREngine-common.h" #include "KREngine-common.h"
#include "KRShader.h"
KRSourceManager::KRSourceManager(KRContext& context) : KRResourceManager(context) KRSourceManager::KRSourceManager(KRContext& context) : KRResourceManager(context)
{ {
@@ -76,20 +77,7 @@ void KRSourceManager::add(KRSource* source)
KRResource* KRSourceManager::loadResource(const std::string& name, const std::string& extension, KRDataBlock* data) KRResource* KRSourceManager::loadResource(const std::string& name, const std::string& extension, KRDataBlock* data)
{ {
if (extension.compare("vert") == 0 || if (getShaderStageFromExtension(extension.c_str()) != 0 ||
extension.compare("frag") == 0 ||
extension.compare("tesc") == 0 ||
extension.compare("tese") == 0 ||
extension.compare("geom") == 0 ||
extension.compare("comp") == 0 ||
extension.compare("mesh") == 0 ||
extension.compare("task") == 0 ||
extension.compare("rgen") == 0 ||
extension.compare("rint") == 0 ||
extension.compare("rahit") == 0 ||
extension.compare("rchit") == 0 ||
extension.compare("rmiss") == 0 ||
extension.compare("rcall") == 0 ||
extension.compare("glsl") == 0 || extension.compare("glsl") == 0 ||
extension.compare("options") == 0) { extension.compare("options") == 0) {
return load(name, extension, data); return load(name, extension, data);
@@ -99,20 +87,7 @@ KRResource* KRSourceManager::loadResource(const std::string& name, const std::st
KRResource* KRSourceManager::getResource(const std::string& name, const std::string& extension) KRResource* KRSourceManager::getResource(const std::string& name, const std::string& extension)
{ {
if (extension.compare("vert") == 0 || if (getShaderStageFromExtension(extension.c_str()) != 0 ||
extension.compare("frag") == 0 ||
extension.compare("tesc") == 0 ||
extension.compare("tese") == 0 ||
extension.compare("geom") == 0 ||
extension.compare("comp") == 0 ||
extension.compare("mesh") == 0 ||
extension.compare("task") == 0 ||
extension.compare("rgen") == 0 ||
extension.compare("rint") == 0 ||
extension.compare("rahit") == 0 ||
extension.compare("rchit") == 0 ||
extension.compare("rmiss") == 0 ||
extension.compare("rcall") == 0 ||
extension.compare("glsl") == 0 || extension.compare("glsl") == 0 ||
extension.compare("options") == 0) { extension.compare("options") == 0) {
return get(name, extension); return get(name, extension);