diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp index 43d965f2fc..6dd09eb912 100644 --- a/src/video_core/shader/control_flow.cpp +++ b/src/video_core/shader/control_flow.cpp @@ -26,17 +26,29 @@ using Tegra::Shader::OpCode; constexpr s32 unassigned_branch = -2; +enum class JumpLabel : u32 { + SSYClass = 0, + PBKClass = 1, +}; + +struct JumpItem { + JumpLabel type; + u32 address; + + bool operator==(const JumpItem& other) const { + return std::tie(type, address) == std::tie(other.type, other.address); + } +}; + struct Query { u32 address{}; - std::stack ssy_stack{}; - std::stack pbk_stack{}; + std::stack stack{}; }; struct BlockStack { BlockStack() = default; - explicit BlockStack(const Query& q) : ssy_stack{q.ssy_stack}, pbk_stack{q.pbk_stack} {} - std::stack ssy_stack{}; - std::stack pbk_stack{}; + explicit BlockStack(const Query& q) : stack{q.stack} {} + std::stack stack{}; }; template @@ -77,8 +89,7 @@ struct CFGRebuildState { std::list queries; std::unordered_map registered; std::set labels; - std::map ssy_labels; - std::map pbk_labels; + std::map jump_labels; std::unordered_map stacks; ASTManager* manager{}; }; @@ -411,13 +422,15 @@ std::pair ParseCode(CFGRebuildState& state, u32 address) case OpCode::Id::SSY: { const u32 target = offset + instr.bra.GetBranchTarget(); insert_label(state, target); - state.ssy_labels.emplace(offset, target); + JumpItem it = {JumpLabel::SSYClass, target}; + state.jump_labels.emplace(offset, it); break; } case OpCode::Id::PBK: { const u32 target = offset + instr.bra.GetBranchTarget(); insert_label(state, target); - state.pbk_labels.emplace(offset, target); + JumpItem it = {JumpLabel::PBKClass, target}; + state.jump_labels.emplace(offset, it); break; } case OpCode::Id::BRX: { @@ -513,7 +526,7 @@ bool TryInspectAddress(CFGRebuildState& state) { } bool TryQuery(CFGRebuildState& state) { - const auto gather_labels = [](std::stack& cc, std::map& labels, + const auto gather_labels = [](std::stack& cc, std::map& labels, BlockInfo& block) { auto gather_start = labels.lower_bound(block.start); const auto gather_end = labels.upper_bound(block.end); @@ -522,6 +535,19 @@ bool TryQuery(CFGRebuildState& state) { ++gather_start; } }; + const auto pop_labels = [](JumpLabel type, SingleBranch* branch, Query& query) -> bool { + while (!query.stack.empty() && query.stack.top().type != type) { + query.stack.pop(); + } + if (query.stack.empty()) { + return false; + } + if (branch->address == unassigned_branch) { + branch->address = query.stack.top().address; + } + query.stack.pop(); + return true; + }; if (state.queries.empty()) { return false; } @@ -534,8 +560,7 @@ bool TryQuery(CFGRebuildState& state) { // consumes a label. Schedule new queries accordingly if (block.visited) { BlockStack& stack = state.stacks[q.address]; - const bool all_okay = (stack.ssy_stack.empty() || q.ssy_stack == stack.ssy_stack) && - (stack.pbk_stack.empty() || q.pbk_stack == stack.pbk_stack); + const bool all_okay = (stack.stack.empty() || q.stack == stack.stack); state.queries.pop_front(); return all_okay; } @@ -544,8 +569,7 @@ bool TryQuery(CFGRebuildState& state) { Query q2(q); state.queries.pop_front(); - gather_labels(q2.ssy_stack, state.ssy_labels, block); - gather_labels(q2.pbk_stack, state.pbk_labels, block); + gather_labels(q2.stack, state.jump_labels, block); if (std::holds_alternative(*block.branch)) { auto* branch = std::get_if(block.branch.get()); if (!branch->condition.IsUnconditional()) { @@ -555,16 +579,10 @@ bool TryQuery(CFGRebuildState& state) { auto& conditional_query = state.queries.emplace_back(q2); if (branch->is_sync) { - if (branch->address == unassigned_branch) { - branch->address = conditional_query.ssy_stack.top(); - } - conditional_query.ssy_stack.pop(); + pop_labels(JumpLabel::SSYClass, branch, conditional_query); } if (branch->is_brk) { - if (branch->address == unassigned_branch) { - branch->address = conditional_query.pbk_stack.top(); - } - conditional_query.pbk_stack.pop(); + pop_labels(JumpLabel::PBKClass, branch, conditional_query); } conditional_query.address = branch->address; return true; @@ -675,7 +693,7 @@ std::unique_ptr ScanFlow(const ProgramCode& program_code, if (settings.depth != CompileDepth::FlowStack) { // Decompile Stacks - state.queries.push_back(Query{state.start, {}, {}}); + state.queries.push_back(Query{state.start, {}}); decompiled = true; while (!state.queries.empty()) { if (!TryQuery(state)) {