#!/usr/bin/env python3
"""Tiny in-order issue-width simulator for the multi-issue lab."""

from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path
import re


# Opcode classes used by the parser and by the pairing rules.
LOAD_OPS = {"LW"}
STORE_OPS = {"SW"}
# Every opcode that touches memory; two of these cannot share a cycle.
MEM_OPS = LOAD_OPS.union(STORE_OPS)
# Control-flow opcodes; these never pair with another instruction.
BRANCH_OPS = {"BEQ", "BNE", "BLT", "BGE", "JAL", "JALR"}
# Opcodes whose first operand is never recorded as a destination register.
NO_DST = {"NOP"} | STORE_OPS | BRANCH_OPS
# Integer register names X0..X31 (uppercase), bounded by word boundaries so
# they can be found inside operands such as "0(X2)".
REGISTER_RE = re.compile(r"\bX(?:[0-9]|[12][0-9]|3[01])\b")


@dataclass(frozen=True)
class Instruction:
    line_no: int
    text: str
    op: str
    dst: str | None
    srcs: tuple[str, ...]
    is_mem: bool
    is_branch: bool


@dataclass(frozen=True)
class PairResult:
    """Verdict on whether two adjacent instructions may issue together."""

    allowed: bool  # True when both instructions may issue in one cycle
    # "allowed" by default; when blocked, can_pair sets the rule that fired:
    # "branch", "structural" or "raw".
    reason: str = "allowed"


def is_register(token: str) -> bool:
    """Return True when *token* is exactly one register name X0..X31.

    Matching is case-insensitive because the token is uppercased first.
    """
    candidate = token.upper()
    whole_match = REGISTER_RE.fullmatch(candidate)
    return whole_match is not None


def clean_token(token: str) -> str:
    """Normalize an operand token.

    Surrounding whitespace is trimmed first, then trailing commas are
    dropped, and finally the token is uppercased.
    """
    trimmed = token.strip()
    without_commas = trimmed.rstrip(",")
    return without_commas.upper()


def registers_in(tokens: list[str]) -> list[str]:
    """Return every register name found in *tokens*, in order of appearance.

    Registers embedded inside larger operands (such as the base register in
    ``0(X2)``) are found as well, and duplicates are kept.
    """
    return [
        register
        for token in tokens
        for register in REGISTER_RE.findall(token.upper())
    ]


def parse_instruction(line: str, line_no: int) -> Instruction | None:
    """Parse one trace line into an :class:`Instruction`.

    Everything after the first ``#`` is a comment. A line that is blank once
    the comment is removed yields ``None``. Operands may be separated by
    commas and/or whitespace, and every token is uppercased.

    Register roles:
      * ``NOP`` reads and writes nothing.
      * Stores and branches read every register mentioned in their operands
        and record no destination.
      * Every other op writes its first operand (only if that operand is a
        register) and reads the registers in the remaining operands.
    """
    code = line.partition("#")[0]
    text = code.strip()
    if not text:
        return None

    op, *operands = (
        clean_token(tok) for tok in text.replace(",", " ").split()
    )

    if op == "NOP":
        dst, srcs = None, []
    elif op in STORE_OPS or op in BRANCH_OPS:
        dst, srcs = None, registers_in(operands)
    else:
        head = operands[0] if operands else None
        dst = head if head is not None and is_register(head) else None
        srcs = registers_in(operands[1:])

    return Instruction(
        line_no=line_no,
        text=text,
        op=op,
        dst=dst,
        srcs=tuple(srcs),
        is_mem=op in MEM_OPS,
        is_branch=op in BRANCH_OPS,
    )


def load_trace(path: Path) -> list[Instruction]:
    """Read the UTF-8 trace at *path* and return its instructions in order.

    Lines that are blank or contain only a comment are dropped. Each kept
    instruction records its 1-based line number in the file.
    """
    source_lines = path.read_text(encoding="utf-8").splitlines()
    parsed = (
        parse_instruction(text, number)
        for number, text in enumerate(source_lines, start=1)
    )
    return [inst for inst in parsed if inst is not None]


def can_pair(inst1: Instruction, inst2: Instruction) -> PairResult:
    """Decide whether *inst2* may issue in the same cycle as *inst1*.

    The rules are checked in priority order, and the first violation
    determines the reason:
      1. ``"branch"``: neither instruction may be a branch.
      2. ``"structural"``: at most one memory operation per cycle.
      3. ``"raw"``: *inst2* may not read the register that *inst1* writes.

    Register X0 gets no special treatment in the RAW check.
    """
    checks = (
        ("branch", inst1.is_branch or inst2.is_branch),
        ("structural", inst1.is_mem and inst2.is_mem),
        ("raw", inst1.dst is not None and inst1.dst in inst2.srcs),
    )
    for reason, violated in checks:
        if violated:
            return PairResult(False, reason)
    return PairResult(True)


def simulate(instructions: list[Instruction], width: int) -> dict[str, float | int]:
    """Run a greedy in-order issue model over *instructions*.

    Each cycle issues the next instruction. With ``width == 2`` the model
    also tries to issue the following instruction in the same cycle, subject
    to :func:`can_pair`. When the pair is rejected, only one instruction
    issues, and the rejection is tallied by its reason.

    Returns:
        A dict with keys ``instructions``, ``cycles``, ``ipc`` (0.0 for an
        empty trace), ``unused_slots``, ``blocked_pairs``, ``raw_blocks``,
        ``structural_blocks`` and ``branch_blocks``.

    Raises:
        ValueError: if *width* is not 1 or 2.
    """
    if width not in {1, 2}:
        raise ValueError(
            "This lab simulator supports only --width 1 and --width 2."
        )

    total = len(instructions)
    blocked_by = {"raw": 0, "structural": 0, "branch": 0}
    blocked_pairs = 0
    cycles = 0
    index = 0

    while index < total:
        cycles += 1
        has_partner = width == 2 and index + 1 < total
        if not has_partner:
            index += 1
            continue

        verdict = can_pair(instructions[index], instructions[index + 1])
        if verdict.allowed:
            index += 2
            continue

        # Rejected pair: only the older instruction issues this cycle.
        blocked_pairs += 1
        if verdict.reason in blocked_by:
            blocked_by[verdict.reason] += 1
        index += 1

    return {
        "instructions": total,
        "cycles": cycles,
        "ipc": total / cycles if cycles else 0.0,
        "unused_slots": cycles * width - total,
        "blocked_pairs": blocked_pairs,
        "raw_blocks": blocked_by["raw"],
        "structural_blocks": blocked_by["structural"],
        "branch_blocks": blocked_by["branch"],
    }


def main() -> int:
    """Parse the command line, simulate the trace, and print a summary.

    Returns:
        0 on success. Bad arguments and an unreadable trace file are both
        reported through ``argparse``, which exits with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Run the simple in-order multi-issue lab simulator."
    )
    parser.add_argument("trace", type=Path, help="Path to a .trace file")
    parser.add_argument(
        "--width",
        type=int,
        choices=(1, 2),
        required=True,
        help="Issue width: 1 (single-issue) or 2 (dual-issue)",
    )
    args = parser.parse_args()

    try:
        instructions = load_trace(args.trace)
    except (OSError, UnicodeDecodeError) as exc:
        # A missing, unreadable or non-UTF-8 trace is a user error. Report it
        # like any other bad argument instead of crashing with a traceback.
        # parser.error() raises SystemExit, so control never falls through.
        parser.error(f"cannot read trace {args.trace}: {exc}")

    stats = simulate(instructions, args.width)

    print(f"Trace: {args.trace}")
    print(f"Issue width: {args.width}")
    print(f"Instructions: {stats['instructions']}")
    print(f"Cycles: {stats['cycles']}")
    print(f"IPC: {stats['ipc']:.2f}")
    print(f"Unused issue slots: {stats['unused_slots']}")
    print(f"Blocked pairs: {stats['blocked_pairs']}")
    print(f"Blocked by RAW: {stats['raw_blocks']}")
    print(f"Blocked by structural limit: {stats['structural_blocks']}")
    print(f"Blocked by branch rule: {stats['branch_blocks']}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
