arch refactor + backend ir start

This commit is contained in:
2026-06-06 21:00:39 -04:00
parent 0ac7c5cc02
commit 4587f687b9
22 changed files with 547 additions and 661 deletions
+12
View File
@@ -0,0 +1,12 @@
use crate::{
backend::{LinkedProgram, Program},
io::CompilerMsg,
};
pub mod x86_64;
pub trait Arch: Sized {
const NAME: &str;
type Asm;
fn compile(p: &Program<Self>) -> Result<LinkedProgram, CompilerMsg>;
}
+46
View File
@@ -0,0 +1,46 @@
use crate::{arch::x86_64::RegMode, backend::Symbol};
pub struct Asm {
pub instrs: Vec<Instr>,
}
#[derive(Clone, Copy)]
pub enum Instr {
Mov { dst: RegMode, src: RegImm },
Int { code: u8 },
Syscall,
Lea { dst: RegMode, sym: Symbol },
}
#[derive(Clone, Copy)]
pub enum RegImm {
Reg(RegMode),
Imm(u64),
}
impl From<RegMode> for RegImm {
fn from(value: RegMode) -> Self {
Self::Reg(value)
}
}
impl From<u64> for RegImm {
fn from(value: u64) -> Self {
Self::Imm(value)
}
}
pub fn mov(dst: RegMode, src: impl Into<RegImm>) -> Instr {
Instr::Mov {
dst,
src: src.into(),
}
}
pub fn lea(dst: RegMode, sym: Symbol) -> Instr {
Instr::Lea { dst, sym }
}
pub fn int(code: u8) -> Instr {
Instr::Int { code }
}
+141
View File
@@ -0,0 +1,141 @@
use super::*;
use crate::backend::{Addr, LinkedProgram, SymTable, Symbol};
pub struct Encoder {
pub data: Vec<u8>,
pub sym_tab: SymTable,
pub missing: Vec<(usize, Symbol)>,
}
pub fn encode_program(p: &Program<X86_64>) -> Result<LinkedProgram, CompilerMsg> {
let mut encoder = Encoder::new(p.sym_count());
p.encode_data(&mut encoder.data, &mut encoder.sym_tab);
for f in &p.funcs {
let addr = encoder.data.len();
encoder.sym_tab.insert(f.sym, addr as u64);
for instr in &f.instrs {
compile_instr(&mut encoder, instr)?;
}
}
for (pos, sym) in encoder.missing.drain(..) {
let addr = encoder
.sym_tab
.get(sym)
.ok_or(CompilerMsg::from(format!("unknown symbol {sym:?}")))?;
encoder.data[pos..pos + 4].copy_from_slice(&addr_offset(pos, addr))
}
Ok(LinkedProgram {
code: encoder.data,
entry: p.entry.and_then(|e| encoder.sym_tab.get(e)),
})
}
type BInstr = crate::backend::Instr<X86_64>;
fn compile_instr(encoder: &mut Encoder, instr: &BInstr) -> Result<(), CompilerMsg> {
match instr {
BInstr::Copy { dst, src } => todo!(),
BInstr::Asm(asm) => {
for i in &asm.instrs {
encoder.asm(*i)?;
}
}
}
Ok(())
}
impl Encoder {
// assembly
pub fn mov(&mut self, dst: RegMode, src: impl Into<RegImm>) -> Result<(), CompilerMsg> {
let src = src.into();
let width = dst.width;
if width == BitWidth::B16 {
self.data.push(0x66);
}
let dst8 = dst.gt8();
let b64 = width == BitWidth::B64;
let b8 = width == BitWidth::B8;
let src8 = if let RegImm::Reg(src) = src {
src.gt8()
} else {
false
};
// special 64-bit / register 4-7 indicator
if dst8 || src8 || b64 || (dst.gt4() && !dst.high) {
self.data
.push(0x40 | dst8 as u8 | ((b64 as u8) << 3) | ((src8 as u8) << 2));
}
match src {
RegImm::Reg(src) => {
if dst.width != src.width {
return Err("src and dst are not the same size".into());
}
self.data.push(0x88 | !b8 as u8);
let modrm = 0b11_000_000 | (src.base() << 3) | dst.base();
self.data.push(modrm);
}
RegImm::Imm(imm) => {
if imm > width.max() {
return Err("immediate cannot fit in register".into());
}
self.data.push(0xb0 | ((!b8 as u8) << 3) | dst.base());
self.data.extend(&imm.to_le_bytes()[..width.bytes()]);
}
}
Ok(())
}
pub fn lea(&mut self, dst: RegMode, sym: Symbol) {
self.data.extend([
0x48 | ((dst.gt8() as u8) << 2),
0x8d,
0x05 | (dst.base() << 3),
]);
let Some(addr) = self.sym_tab.get(sym) else {
let pos = self.data.len();
self.data.extend([0; 4]);
self.missing.push((pos, sym));
return;
};
self.data.extend(addr_offset(self.data.len(), addr));
}
pub fn int(&mut self, code: u8) {
self.data.extend([0xcd, code])
}
pub fn syscall(&mut self) {
self.data.extend([0x0f, 0x05])
}
pub fn asm(&mut self, instr: Instr) -> Result<(), CompilerMsg> {
match instr {
Instr::Mov { dst, src } => self.mov(dst, src)?,
Instr::Int { code } => self.int(code),
Instr::Syscall => self.syscall(),
Instr::Lea { dst, sym } => self.lea(dst, sym),
}
Ok(())
}
}
/// assumes the next instruction is directly after
fn addr_offset(pos: usize, addr: Addr) -> [u8; 4] {
let pos = (pos + 4) as i32;
let offset = addr as i32 - pos;
offset.to_le_bytes()
}
impl Encoder {
pub fn new(sym_count: usize) -> Self {
Self {
data: Default::default(),
sym_tab: SymTable::new(sym_count),
missing: Default::default(),
}
}
}
+25
View File
@@ -0,0 +1,25 @@
mod asm;
mod encode;
mod reg;
#[cfg(test)]
mod test;
use crate::{
arch::Arch,
backend::{LinkedProgram, Program},
io::CompilerMsg,
};
pub use asm::*;
pub use encode::*;
pub use reg::*;
pub struct X86_64;
impl Arch for X86_64 {
const NAME: &str = "x86_64";
type Asm = Asm;
fn compile(p: &Program<Self>) -> Result<LinkedProgram, CompilerMsg> {
encode_program(p)
}
}
+106
View File
@@ -0,0 +1,106 @@
#[derive(Clone, Copy)]
pub struct Reg(u8);
#[derive(Clone, Copy)]
pub struct RegMode {
pub reg: Reg,
pub width: BitWidth,
pub high: bool,
}
#[derive(Clone, Copy, PartialEq)]
pub enum BitWidth {
B64,
B32,
B16,
B8,
}
impl RegMode {
pub fn base(&self) -> u8 {
self.reg.0 & 0b111
}
/// checks if register is not one of the first 8 (0-7)
pub fn gt8(&self) -> bool {
self.reg.0 >= 0b1000
}
pub fn gt4(&self) -> bool {
self.reg.0 >= 0b0100
}
}
def_regs! {
0b0000 : rax eax ax al ah=spl,
0b0001 : rcx ecx cx cl ch=bpl,
0b0010 : rdx edx dx dl dh=sil,
0b0011 : rbx ebx bx bl bh=dil,
0b0100 : rsp esp sp spl,
0b0101 : rbp ebp bp bpl,
0b0110 : rsi esi si sil,
0b0111 : rdi edi di dil,
0b1000 : r8 r8d r8w r8b,
0b1001 : r9 r9d r9w r9b,
0b1010 : r10 r10d r10w r10b,
0b1011 : r11 r11d r11w r11b,
0b1100 : r12 r12d r12w r12b,
0b1101 : r13 r13d r13w r13b,
0b1110 : r14 r14d r14w r14b,
0b1111 : r15 r15d r15w r15b,
}
impl BitWidth {
pub const fn max(&self) -> u64 {
match self {
Self::B64 => u64::MAX,
Self::B32 => u32::MAX as u64,
Self::B16 => u16::MAX as u64,
Self::B8 => u8::MAX as u64,
}
}
pub const fn bytes(&self) -> usize {
match self {
Self::B64 => 8,
Self::B32 => 4,
Self::B16 => 2,
Self::B8 => 1,
}
}
}
macro_rules! def_regs {
($($val:literal : $B64:ident $B32:ident $B16:ident $B8:ident $($B8H:ident=$hval:expr)?,)*) => {
$(
#[allow(non_upper_case_globals)]
pub const $B64: RegMode = RegMode { reg: Reg($val), width: BitWidth::B64, high: false };
#[allow(non_upper_case_globals)]
pub const $B32: RegMode = RegMode { reg: Reg($val), width: BitWidth::B32, high: false };
#[allow(non_upper_case_globals)]
pub const $B16: RegMode = RegMode { reg: Reg($val), width: BitWidth::B16, high: false };
#[allow(non_upper_case_globals)]
pub const $B8 : RegMode = RegMode { reg: Reg($val), width: BitWidth::B8, high: false };
$(
#[allow(non_upper_case_globals)]
pub const $B8H: RegMode = RegMode { reg: $hval.reg, width: BitWidth::B8, high: true };
)?
)*
impl RegMode {
pub fn parse(s: &str) -> Option<Self> {
Some(match s.to_lowercase().as_str() {
$(
stringify!($B64) => $B64,
stringify!($B32) => $B32,
stringify!($B16) => $B16,
stringify!($B8 ) => $B8,
$(
stringify!($B8H) => $B8H,
)?
)*
_ => return None,
})
}
}
};
}
use def_regs;
+76
View File
@@ -0,0 +1,76 @@
use super::*;
fn eq(expected: impl AsRef<[u8]>, asm: Instr) {
let expected = expected.as_ref();
let mut encoder = Encoder::new(0);
if let Err(e) = encoder.asm(asm) {
panic!("expected {expected:x?}, failed to compile: {}", e.msg);
}
let res = encoder.data;
assert_eq!(expected, &res[..], "expected {expected:x?}, got {res:x?}");
}
#[test]
fn reg_reg() {
// used objdump on some nasm compiled assembly
eq([0x48, 0x89, 0xd8], mov(rax, rbx));
eq([0x89, 0xd8], mov(eax, ebx));
eq([0x66, 0x89, 0xd8], mov(ax, bx));
eq([0x88, 0xd8], mov(al, bl));
eq([0x88, 0xfc], mov(ah, bh));
eq([0x88, 0xf8], mov(al, bh));
eq([0x88, 0xdc], mov(ah, bl));
eq([0x40, 0x88, 0xe7], mov(dil, spl));
eq([0x4d, 0x89, 0xc8], mov(r8, r9));
eq([0x45, 0x89, 0xc8], mov(r8d, r9d));
eq([0x66, 0x45, 0x89, 0xc8], mov(r8w, r9w));
eq([0x45, 0x88, 0xc8], mov(r8b, r9b));
eq([0x49, 0x89, 0xc0], mov(r8, rax));
eq([0x4c, 0x89, 0xc0], mov(rax, r8));
eq([0x4d, 0x89, 0xd1], mov(r9, r10));
eq([0x4d, 0x89, 0xe0], mov(r8, r12));
}
#[test]
fn reg_imm() {
eq(
[0x49, 0xbf, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12],
mov(r15, 0x123456789abcdef0),
);
eq(
[0x49, 0xb8, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12],
mov(r8, 0x123456789abcdef0),
);
eq(
[0x49, 0xb9, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12],
mov(r9, 0x123456789abcdef0),
);
eq([0x41, 0xb9, 0x78, 0x56, 0x34, 0x12], mov(r9d, 0x12345678));
eq([0x66, 0x41, 0xb9, 0x34, 0x12], mov(r9w, 0x1234));
eq([0x41, 0xb1, 0x12], mov(r9b, 0x12));
eq([0x41, 0xb0, 0x12], mov(r8b, 0x12));
eq([0x41, 0xb7, 0x12], mov(r15b, 0x12));
eq(
[0x48, 0xb8, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12],
mov(rax, 0x123456789abcdef0),
);
eq(
[0x48, 0xbb, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12],
mov(rbx, 0x123456789abcdef0),
);
eq(
[0x48, 0xbf, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12],
mov(rdi, 0x123456789abcdef0),
);
eq([0xbb, 0x78, 0x56, 0x34, 0x12], mov(ebx, 0x12345678));
eq([0x66, 0xbb, 0x34, 0x12], mov(bx, 0x1234));
eq([0xb3, 0x12], mov(bl, 0x12));
eq([0xb7, 0x12], mov(bh, 0x12));
eq([0xb4, 0x12], mov(ah, 0x12));
eq([0x40, 0xb7, 0x12], mov(dil, 0x12));
}