Moved ShaderStage to KRShader and expanded to include all stages.
Added mapping functions, getShaderStageFromExtension and getShaderStageFlagBitsFromShaderStage. KRShader::m_stage is now typed as ShaderStage.
This commit is contained in:
@@ -337,7 +337,7 @@ KRResource* KRContext::loadResource(const std::string& file_name, KRDataBlock* d
|
||||
} else if (extension.compare("spv") == 0) {
|
||||
// SPIR-V shader binary
|
||||
resource = m_pShaderManager->load(name, extension, data);
|
||||
} else if (getShaderStageFromExtension(extension.c_str()) != 0) {
|
||||
} else if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid) {
|
||||
// Shader source
|
||||
resource = m_pSourceManager->load(name, extension, data);
|
||||
} else if (extension.compare("glsl") == 0) {
|
||||
|
||||
@@ -178,12 +178,12 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
|
||||
binding.descriptorType = static_cast<VkDescriptorType>(binding_reflect.descriptor_type);
|
||||
binding.descriptorCount = binding_reflect.count;
|
||||
binding.pImmutableSamplers = nullptr;
|
||||
binding.stageFlags = shader->getShaderStage();
|
||||
binding.stageFlags = shader->getShaderStageFlagBits();
|
||||
}
|
||||
|
||||
VkPipelineShaderStageCreateInfo& stageInfo = stages[stage_count++];
|
||||
stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
stageInfo.stage = shader->getShaderStage();
|
||||
stageInfo.stage = shader->getShaderStageFlagBits();
|
||||
if (stageInfo.stage == VK_SHADER_STAGE_VERTEX_BIT) {
|
||||
|
||||
for (uint32_t i = 0; i < reflection->input_variable_count; i++) {
|
||||
@@ -206,10 +206,10 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
|
||||
}
|
||||
}
|
||||
|
||||
initPushConstantStage(ShaderStages::vertex, reflection);
|
||||
initPushConstantStage(ShaderStage::vert, reflection);
|
||||
|
||||
} else if (stageInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT) {
|
||||
initPushConstantStage(ShaderStages::fragment, reflection);
|
||||
initPushConstantStage(ShaderStage::frag, reflection);
|
||||
} else {
|
||||
// failed! TODO - Error handling
|
||||
}
|
||||
@@ -395,22 +395,7 @@ KRPipeline::KRPipeline(KRContext& context, KRSurface& surface, const PipelineInf
|
||||
VkPushConstantRange push_constant{};
|
||||
push_constant.offset = 0;
|
||||
push_constant.size = pushConstants.bufferSize;
|
||||
|
||||
switch (static_cast<ShaderStages>(iStage)) {
|
||||
case ShaderStages::vertex:
|
||||
push_constant.stageFlags = VK_SHADER_STAGE_VERTEX_BIT;
|
||||
break;
|
||||
case ShaderStages::fragment:
|
||||
push_constant.stageFlags = VK_SHADER_STAGE_FRAGMENT_BIT;
|
||||
break;
|
||||
case ShaderStages::geometry:
|
||||
push_constant.stageFlags = VK_SHADER_STAGE_GEOMETRY_BIT;
|
||||
break;
|
||||
case ShaderStages::compute:
|
||||
push_constant.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
break;
|
||||
}
|
||||
|
||||
push_constant.stageFlags = getShaderStageFlagBitsFromShaderStage(static_cast<ShaderStage>(iStage));
|
||||
pushConstantsLayoutInfo.pPushConstantRanges = &push_constant;
|
||||
pushConstantsLayoutInfo.pushConstantRangeCount = 1;
|
||||
|
||||
@@ -511,7 +496,7 @@ KRPipeline::~KRPipeline()
|
||||
}
|
||||
}
|
||||
|
||||
void KRPipeline::initPushConstantStage(ShaderStages stage, const SpvReflectShaderModule* reflection)
|
||||
void KRPipeline::initPushConstantStage(ShaderStage stage, const SpvReflectShaderModule* reflection)
|
||||
{
|
||||
PushConstantStageInfo& pushConstants = m_pushConstants[static_cast<int>(stage)];
|
||||
for (int i = 0; i < reflection->push_constant_block_count; i++) {
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
#include "KRNode.h"
|
||||
#include "KRViewport.h"
|
||||
#include "KRMesh.h"
|
||||
#include "KRShader.h"
|
||||
|
||||
class KRShader;
|
||||
class KRSurface;
|
||||
@@ -292,17 +293,6 @@ public:
|
||||
|
||||
static const size_t kPushConstantCount = static_cast<size_t>(PushConstant::NUM_PUSH_CONSTANTS);
|
||||
|
||||
enum class ShaderStages : uint8_t
|
||||
{
|
||||
vertex = 0,
|
||||
fragment,
|
||||
geometry,
|
||||
compute,
|
||||
shaderStageCount
|
||||
};
|
||||
|
||||
static const size_t kShaderStageCount = static_cast<size_t>(ShaderStages::shaderStageCount);
|
||||
|
||||
bool hasPushConstant(PushConstant location) const;
|
||||
void setPushConstant(PushConstant location, float value);
|
||||
void setPushConstant(PushConstant location, int value);
|
||||
@@ -325,7 +315,7 @@ private:
|
||||
int bufferSize;
|
||||
VkPipelineLayout layout;
|
||||
};
|
||||
PushConstantStageInfo m_pushConstants[static_cast<size_t>(ShaderStages::shaderStageCount)];
|
||||
PushConstantStageInfo m_pushConstants[static_cast<size_t>(ShaderStage::ShaderStageCount)];
|
||||
|
||||
char m_szKey[256];
|
||||
|
||||
@@ -333,5 +323,5 @@ private:
|
||||
VkPipelineLayout m_pipelineLayout;
|
||||
VkPipeline m_graphicsPipeline;
|
||||
|
||||
void initPushConstantStage(ShaderStages stage, const SpvReflectShaderModule* reflection);
|
||||
void initPushConstantStage(ShaderStage stage, const SpvReflectShaderModule* reflection);
|
||||
};
|
||||
|
||||
@@ -32,41 +32,76 @@
|
||||
#include "KRShader.h"
|
||||
#include "spirv_reflect.h"
|
||||
|
||||
VkShaderStageFlagBits getShaderStageFromExtension(const char* extension)
|
||||
ShaderStage getShaderStageFromExtension(const char* extension)
|
||||
{
|
||||
if (strcmp(extension, "vert") == 0) {
|
||||
return VK_SHADER_STAGE_VERTEX_BIT;
|
||||
return ShaderStage::vert;
|
||||
} else if (strcmp(extension, "frag") == 0) {
|
||||
return VK_SHADER_STAGE_FRAGMENT_BIT;
|
||||
return ShaderStage::frag;
|
||||
} else if (strcmp(extension, "tesc") == 0) {
|
||||
return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
|
||||
return ShaderStage::tesc;
|
||||
} else if (strcmp(extension, "tese") == 0) {
|
||||
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
|
||||
return ShaderStage::tese;
|
||||
} else if (strcmp(extension, "geom") == 0) {
|
||||
return VK_SHADER_STAGE_GEOMETRY_BIT;
|
||||
return ShaderStage::geom;
|
||||
} else if (strcmp(extension, "comp") == 0) {
|
||||
return VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
return ShaderStage::comp;
|
||||
} else if (strcmp(extension, "mesh") == 0) {
|
||||
return VK_SHADER_STAGE_MESH_BIT_NV;
|
||||
return ShaderStage::mesh;
|
||||
} else if (strcmp(extension, "task") == 0) {
|
||||
return VK_SHADER_STAGE_TASK_BIT_NV;
|
||||
return ShaderStage::task;
|
||||
} else if (strcmp(extension, "rgen") == 0) {
|
||||
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
|
||||
return ShaderStage::rgen;
|
||||
} else if (strcmp(extension, "rint") == 0) {
|
||||
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
|
||||
return ShaderStage::rint;
|
||||
} else if (strcmp(extension, "rahit") == 0) {
|
||||
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
|
||||
return ShaderStage::rahit;
|
||||
} else if (strcmp(extension, "rchit") == 0) {
|
||||
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
|
||||
return ShaderStage::rchit;
|
||||
} else if (strcmp(extension, "rmiss") == 0) {
|
||||
return VK_SHADER_STAGE_MISS_BIT_KHR;
|
||||
} else if (strcmp(extension, "rmiss") == 0) {
|
||||
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
|
||||
return ShaderStage::rmiss;
|
||||
} else if (strcmp(extension, "rcall") == 0) {
|
||||
return ShaderStage::rcall;
|
||||
} else {
|
||||
return (VkShaderStageFlagBits)0;
|
||||
return ShaderStage::Invalid;
|
||||
}
|
||||
}
|
||||
|
||||
VkShaderStageFlagBits getShaderStageFlagBitsFromShaderStage(ShaderStage stage)
|
||||
{
|
||||
switch (stage) {
|
||||
case ShaderStage::vert:
|
||||
return VK_SHADER_STAGE_VERTEX_BIT;
|
||||
case ShaderStage::frag:
|
||||
return VK_SHADER_STAGE_FRAGMENT_BIT;
|
||||
case ShaderStage::tesc:
|
||||
return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
|
||||
case ShaderStage::tese:
|
||||
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
|
||||
case ShaderStage::geom:
|
||||
return VK_SHADER_STAGE_GEOMETRY_BIT;
|
||||
case ShaderStage::comp:
|
||||
return VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
case ShaderStage::mesh:
|
||||
return VK_SHADER_STAGE_MESH_BIT_NV;
|
||||
case ShaderStage::task:
|
||||
return VK_SHADER_STAGE_TASK_BIT_NV;
|
||||
case ShaderStage::rgen:
|
||||
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
|
||||
case ShaderStage::rint:
|
||||
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
|
||||
case ShaderStage::rahit:
|
||||
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
|
||||
case ShaderStage::rchit:
|
||||
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
|
||||
case ShaderStage::rmiss:
|
||||
return VK_SHADER_STAGE_MISS_BIT_KHR;
|
||||
case ShaderStage::rcall:
|
||||
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
|
||||
}
|
||||
return (VkShaderStageFlagBits)0;
|
||||
}
|
||||
|
||||
KRShader::KRShader(KRContext& context, std::string name, std::string extension) : KRResource(context, name)
|
||||
{
|
||||
m_pData = new KRDataBlock();
|
||||
@@ -181,7 +216,12 @@ const SpvReflectShaderModule* KRShader::getReflection()
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
VkShaderStageFlagBits KRShader::getShaderStage() const
|
||||
ShaderStage KRShader::getShaderStage() const
|
||||
{
|
||||
return m_stage;
|
||||
}
|
||||
}
|
||||
|
||||
VkShaderStageFlagBits KRShader::getShaderStageFlagBits() const
|
||||
{
|
||||
return getShaderStageFlagBitsFromShaderStage(m_stage);
|
||||
}
|
||||
|
||||
@@ -37,7 +37,30 @@
|
||||
#include "KRResource.h"
|
||||
#include "spirv_reflect.h"
|
||||
|
||||
VkShaderStageFlagBits getShaderStageFromExtension(const char* extension);
|
||||
enum class ShaderStage : uint8_t
|
||||
{
|
||||
vert = 0,
|
||||
frag,
|
||||
tesc,
|
||||
tese,
|
||||
geom,
|
||||
comp,
|
||||
mesh,
|
||||
task,
|
||||
rgen,
|
||||
rint,
|
||||
rahit,
|
||||
rchit,
|
||||
rmiss,
|
||||
rcall,
|
||||
ShaderStageCount,
|
||||
Invalid = 0xff
|
||||
};
|
||||
|
||||
ShaderStage getShaderStageFromExtension(const char* extension);
|
||||
VkShaderStageFlagBits getShaderStageFlagBitsFromShaderStage(ShaderStage stage);
|
||||
|
||||
static const size_t kShaderStageCount = static_cast<size_t>(ShaderStage::ShaderStageCount);
|
||||
|
||||
class KRShader : public KRResource
|
||||
{
|
||||
@@ -55,7 +78,9 @@ public:
|
||||
|
||||
KRDataBlock* getData();
|
||||
const SpvReflectShaderModule* getReflection();
|
||||
VkShaderStageFlagBits getShaderStage() const;
|
||||
ShaderStage getShaderStage() const;
|
||||
VkShaderStageFlagBits getShaderStageFlagBits() const;
|
||||
|
||||
|
||||
private:
|
||||
|
||||
@@ -68,5 +93,5 @@ private:
|
||||
void parseReflection();
|
||||
void freeReflection();
|
||||
|
||||
VkShaderStageFlagBits m_stage;
|
||||
ShaderStage m_stage;
|
||||
};
|
||||
|
||||
@@ -77,7 +77,7 @@ void KRSourceManager::add(KRSource* source)
|
||||
|
||||
KRResource* KRSourceManager::loadResource(const std::string& name, const std::string& extension, KRDataBlock* data)
|
||||
{
|
||||
if (getShaderStageFromExtension(extension.c_str()) != 0 ||
|
||||
if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid ||
|
||||
extension.compare("glsl") == 0 ||
|
||||
extension.compare("options") == 0) {
|
||||
return load(name, extension, data);
|
||||
@@ -87,7 +87,7 @@ KRResource* KRSourceManager::loadResource(const std::string& name, const std::st
|
||||
|
||||
KRResource* KRSourceManager::getResource(const std::string& name, const std::string& extension)
|
||||
{
|
||||
if (getShaderStageFromExtension(extension.c_str()) != 0 ||
|
||||
if (getShaderStageFromExtension(extension.c_str()) != ShaderStage::Invalid ||
|
||||
extension.compare("glsl") == 0 ||
|
||||
extension.compare("options") == 0) {
|
||||
return get(name, extension);
|
||||
|
||||
Reference in New Issue
Block a user