Moved ShaderStage to KRShader and expanded to include all stages.

Added mapping functions, getShaderStageFromExtension and getShaderStageFlagBitsFromShaderStage.
KRShader::m_stage is now typed as ShaderStage.
This commit is contained in:
2022-09-09 00:36:22 -07:00
parent f20f7f73d6
commit aeaed68efb
6 changed files with 99 additions and 59 deletions

View File

@@ -337,7 +337,7 @@ KRResource* KRContext::loadResource(const std::string& file_name, KRDataBlock* d
} else if (extension.compare("spv") == 0) {
// SPIR-V shader binary
resource = m_pShaderManager->load(name, extension, data);
} else if (getShaderStageFromExtension(extension.c_str()) != 0) {
} else if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid) {
// Shader source
resource = m_pSourceManager->load(name, extension, data);
} else if (extension.compare("glsl") == 0) {

View File

@@ -178,12 +178,12 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
binding.descriptorType = static_cast<VkDescriptorType>(binding_reflect.descriptor_type);
binding.descriptorCount = binding_reflect.count;
binding.pImmutableSamplers = nullptr;
binding.stageFlags = shader->getShaderStage();
binding.stageFlags = shader->getShaderStageFlagBits();
}
VkPipelineShaderStageCreateInfo& stageInfo = stages[stage_count++];
stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
stageInfo.stage = shader->getShaderStage();
stageInfo.stage = shader->getShaderStageFlagBits();
if (stageInfo.stage == VK_SHADER_STAGE_VERTEX_BIT) {
for (uint32_t i = 0; i < reflection->input_variable_count; i++) {
@@ -206,10 +206,10 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
}
}
initPushConstantStage(ShaderStages::vertex, reflection);
initPushConstantStage(ShaderStage::vert, reflection);
} else if (stageInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT) {
initPushConstantStage(ShaderStages::fragment, reflection);
initPushConstantStage(ShaderStage::frag, reflection);
} else {
// failed! TODO - Error handling
}
@@ -395,22 +395,7 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
VkPushConstantRange push_constant{};
push_constant.offset = 0;
push_constant.size = pushConstants.bufferSize;
switch (static_cast<ShaderStages>(iStage)) {
case ShaderStages::vertex:
push_constant.stageFlags = VK_SHADER_STAGE_VERTEX_BIT;
break;
case ShaderStages::fragment:
push_constant.stageFlags = VK_SHADER_STAGE_FRAGMENT_BIT;
break;
case ShaderStages::geometry:
push_constant.stageFlags = VK_SHADER_STAGE_GEOMETRY_BIT;
break;
case ShaderStages::compute:
push_constant.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
break;
}
push_constant.stageFlags = getShaderStageFlagBitsFromShaderStage(static_cast<ShaderStage>(iStage));
pushConstantsLayoutInfo.pPushConstantRanges = &push_constant;
pushConstantsLayoutInfo.pushConstantRangeCount = 1;
@@ -511,7 +496,7 @@ KRPipeline::~KRPipeline()
}
}
void KRPipeline::initPushConstantStage(ShaderStages stage, const SpvReflectShaderModule* reflection)
void KRPipeline::initPushConstantStage(ShaderStage stage, const SpvReflectShaderModule* reflection)
{
PushConstantStageInfo& pushConstants = m_pushConstants[static_cast<int>(stage)];
for (int i = 0; i < reflection->push_constant_block_count; i++) {

View File

@@ -37,6 +37,7 @@
#include "KRNode.h"
#include "KRViewport.h"
#include "KRMesh.h"
#include "KRShader.h"
class KRShader;
class KRSurface;
@@ -292,17 +293,6 @@ public:
static const size_t kPushConstantCount = static_cast<size_t>(PushConstant::NUM_PUSH_CONSTANTS);
enum class ShaderStages : uint8_t
{
vertex = 0,
fragment,
geometry,
compute,
shaderStageCount
};
static const size_t kShaderStageCount = static_cast<size_t>(ShaderStages::shaderStageCount);
bool hasPushConstant(PushConstant location) const;
void setPushConstant(PushConstant location, float value);
void setPushConstant(PushConstant location, int value);
@@ -325,7 +315,7 @@ private:
int bufferSize;
VkPipelineLayout layout;
};
PushConstantStageInfo m_pushConstants[static_cast<size_t>(ShaderStages::shaderStageCount)];
PushConstantStageInfo m_pushConstants[static_cast<size_t>(ShaderStage::ShaderStageCount)];
char m_szKey[256];
@@ -333,5 +323,5 @@ private:
VkPipelineLayout m_pipelineLayout;
VkPipeline m_graphicsPipeline;
void initPushConstantStage(ShaderStages stage, const SpvReflectShaderModule* reflection);
void initPushConstantStage(ShaderStage stage, const SpvReflectShaderModule* reflection);
};

View File

@@ -32,41 +32,76 @@
#include "KRShader.h"
#include "spirv_reflect.h"
VkShaderStageFlagBits getShaderStageFromExtension(const char* extension)
ShaderStage getShaderStageFromExtension(const char* extension)
{
if (strcmp(extension, "vert") == 0) {
return VK_SHADER_STAGE_VERTEX_BIT;
return ShaderStage::vert;
} else if (strcmp(extension, "frag") == 0) {
return VK_SHADER_STAGE_FRAGMENT_BIT;
return ShaderStage::frag;
} else if (strcmp(extension, "tesc") == 0) {
return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
return ShaderStage::tesc;
} else if (strcmp(extension, "tese") == 0) {
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
return ShaderStage::tese;
} else if (strcmp(extension, "geom") == 0) {
return VK_SHADER_STAGE_GEOMETRY_BIT;
return ShaderStage::geom;
} else if (strcmp(extension, "comp") == 0) {
return VK_SHADER_STAGE_COMPUTE_BIT;
return ShaderStage::comp;
} else if (strcmp(extension, "mesh") == 0) {
return VK_SHADER_STAGE_MESH_BIT_NV;
return ShaderStage::mesh;
} else if (strcmp(extension, "task") == 0) {
return VK_SHADER_STAGE_TASK_BIT_NV;
return ShaderStage::task;
} else if (strcmp(extension, "rgen") == 0) {
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
return ShaderStage::rgen;
} else if (strcmp(extension, "rint") == 0) {
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
return ShaderStage::rint;
} else if (strcmp(extension, "rahit") == 0) {
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
return ShaderStage::rahit;
} else if (strcmp(extension, "rchit") == 0) {
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
return ShaderStage::rchit;
} else if (strcmp(extension, "rmiss") == 0) {
return VK_SHADER_STAGE_MISS_BIT_KHR;
} else if (strcmp(extension, "rmiss") == 0) {
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
return ShaderStage::rmiss;
} else if (strcmp(extension, "rcall") == 0) {
return ShaderStage::rcall;
} else {
return (VkShaderStageFlagBits)0;
return ShaderStage::Invalid;
}
}
VkShaderStageFlagBits getShaderStageFlagBitsFromShaderStage(ShaderStage stage)
{
switch (stage) {
case ShaderStage::vert:
return VK_SHADER_STAGE_VERTEX_BIT;
case ShaderStage::frag:
return VK_SHADER_STAGE_FRAGMENT_BIT;
case ShaderStage::tesc:
return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
case ShaderStage::tese:
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
case ShaderStage::geom:
return VK_SHADER_STAGE_GEOMETRY_BIT;
case ShaderStage::comp:
return VK_SHADER_STAGE_COMPUTE_BIT;
case ShaderStage::mesh:
return VK_SHADER_STAGE_MESH_BIT_NV;
case ShaderStage::task:
return VK_SHADER_STAGE_TASK_BIT_NV;
case ShaderStage::rgen:
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
case ShaderStage::rint:
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
case ShaderStage::rahit:
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
case ShaderStage::rchit:
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
case ShaderStage::rmiss:
return VK_SHADER_STAGE_MISS_BIT_KHR;
case ShaderStage::rcall:
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
}
return (VkShaderStageFlagBits)0;
}
KRShader::KRShader(KRContext& context, std::string name, std::string extension) : KRResource(context, name)
{
m_pData = new KRDataBlock();
@@ -181,7 +216,12 @@ const SpvReflectShaderModule* KRShader::getReflection()
return nullptr;
}
VkShaderStageFlagBits KRShader::getShaderStage() const
ShaderStage KRShader::getShaderStage() const
{
return m_stage;
}
VkShaderStageFlagBits KRShader::getShaderStageFlagBits() const
{
return getShaderStageFlagBitsFromShaderStage(m_stage);
}

View File

@@ -37,7 +37,30 @@
#include "KRResource.h"
#include "spirv_reflect.h"
VkShaderStageFlagBits getShaderStageFromExtension(const char* extension);
enum class ShaderStage : uint8_t
{
vert = 0,
frag,
tesc,
tese,
geom,
comp,
mesh,
task,
rgen,
rint,
rahit,
rchit,
rmiss,
rcall,
ShaderStageCount,
Invalid = 0xff
};
ShaderStage getShaderStageFromExtension(const char* extension);
VkShaderStageFlagBits getShaderStageFlagBitsFromShaderStage(ShaderStage stage);
static const size_t kShaderStageCount = static_cast<size_t>(ShaderStage::ShaderStageCount);
class KRShader : public KRResource
{
@@ -55,7 +78,9 @@ public:
KRDataBlock* getData();
const SpvReflectShaderModule* getReflection();
VkShaderStageFlagBits getShaderStage() const;
ShaderStage getShaderStage() const;
VkShaderStageFlagBits getShaderStageFlagBits() const;
private:
@@ -68,5 +93,5 @@ private:
void parseReflection();
void freeReflection();
VkShaderStageFlagBits m_stage;
ShaderStage m_stage;
};

View File

@@ -77,7 +77,7 @@ void KRSourceManager::add(KRSource* source)
KRResource* KRSourceManager::loadResource(const std::string& name, const std::string& extension, KRDataBlock* data)
{
if (getShaderStageFromExtension(extension.c_str()) != 0 ||
if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid ||
extension.compare("glsl") == 0 ||
extension.compare("options") == 0) {
return load(name, extension, data);
@@ -87,7 +87,7 @@ KRResource* KRSourceManager::loadResource(const std::string& name, const std::st
KRResource* KRSourceManager::getResource(const std::string& name, const std::string& extension)
{
if (getShaderStageFromExtension(extension.c_str()) != 0 ||
if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid ||
extension.compare("glsl") == 0 ||
extension.compare("options") == 0) {
return get(name, extension);