From ceebcdc0e33055484092f42af5cfddc84d1fd8ba Mon Sep 17 00:00:00 2001 From: Shadow Cat Date: Fri, 12 Jun 2026 05:09:38 -0400 Subject: [PATCH] tests, but at what cost --- src/arch/x86_64/encode.rs | 79 +++++++++++++--------- src/arch/x86_64/reg.rs | 120 ++++++++++++++++++++++++++++----- src/arch/x86_64/test/bin.rs | 2 +- src/arch/x86_64/test/mod.rs | 2 + src/arch/x86_64/test/nasm.rs | 127 +++++++++++++++++++++++++++++++++++ src/arch/x86_64/test/reg.rs | 3 +- src/arch/x86_64/util.rs | 12 ++-- 7 files changed, 288 insertions(+), 57 deletions(-) create mode 100644 src/arch/x86_64/test/nasm.rs diff --git a/src/arch/x86_64/encode.rs b/src/arch/x86_64/encode.rs index d8e77ab..29a64b1 100644 --- a/src/arch/x86_64/encode.rs +++ b/src/arch/x86_64/encode.rs @@ -15,7 +15,8 @@ pub struct Code { #[derive(Clone, Copy)] pub struct Mem { pub reg: Reg, - pub disp: u32, + pub disp: i32, + pub width: Width, } #[derive(Clone, Copy)] @@ -31,8 +32,8 @@ pub enum RegMem { Mem(Mem), } -pub fn mem(reg: Reg, disp: u32) -> Mem { - Mem { reg, disp } +pub fn mem(reg: Reg, disp: i32, width: Width) -> Mem { + Mem { reg, disp, width } } impl Code { @@ -40,7 +41,7 @@ impl Code { let dst = dst.into(); let src = src.into(); match dst { - RegMem::Reg(dst) => match src { + RegMem::Reg(mut dst) => match src { RegImmMem::Reg(src) => { if dst.width() != src.width() { return Err("src and dst are not same width".into()); @@ -57,13 +58,17 @@ impl Code { self.bytes.push(modrm_regs(src, dst)); } RegImmMem::Imm(src) => { + let src_width = Width::fit(src); + if src_width > dst.width() { + return Err("immediate cannot fit in register".into()); + } self.prefix16(dst); + if src_width <= Width::B32 { + dst.lower64(); + } if dst.requires_rex() { self.bytes.push(rex(dst.width(), 0, 0, dst)); } - if src > dst.width().max() { - return Err("immediate cannot fit in register".into()); - } let opcode = 0xb0 | ((dst.width().gt8() as u8) << 3); self.bytes.push(opcode | dst.base()); self.bytes.extend(&src.to_le_bytes()[..dst.width().bytes()]); @@ -73,14 +78,22 @@ impl Code { RegMem::Mem(dst) => match src { RegImmMem::Reg(src) => todo!(), RegImmMem::Imm(src) => { - if src > u32::MAX as u64 { + let src_width = Width::fit(src); + if src_width == Width::B64 { return Err("cannot move 64 bit immediate into memory".into()); } - let src = src as u32; - - self.bytes.extend([rex(1, dst.reg, 0, 0), 0xc7]); + match dst.reg.width() { + Width::B8 | Width::B16 => return Err("invalid register width".into()), + Width::B32 => self.bytes.push(0x67), + Width::B64 => (), + } + self.prefix16(src_width); + if dst.reg.requires_mem_rex() { + self.bytes.push(rex(src_width, 0, 0, dst.reg)); + } + self.bytes.push(0xc6 | (src_width != Width::B8) as u8); self.modrm_regdisp(dst.reg, dst.disp); - self.bytes.extend(src.to_le_bytes()); + self.bytes.extend(&src.to_le_bytes()[..src_width.bytes()]); } RegImmMem::Mem(_) => return Err("cannot move memory to memory".into()), }, @@ -100,18 +113,16 @@ impl Code { Width::B16 => {} _ => return Err("register must be 64 or 16 bit".into()), }, - RegImmMem::Imm(imm) => match imm.try_into() { - Ok(imm) => { - const U8: u32 = 2 << 8; - if let 0..U8 = imm { - self.bytes.push(0x6a); - self.bytes.push(imm as u8); - } else { - self.bytes.push(0x68); - self.bytes.extend(imm.to_le_bytes()); - } + RegImmMem::Imm(imm) => match Width::fit(imm) { + Width::B8 => { + self.bytes.push(0x6a); + self.bytes.push(imm as u8); } - Err(_) => return Err("immediate must be 32 bit".into()), + Width::B16 | Width::B32 => { + self.bytes.push(0x68); + self.bytes.extend((imm as u32).to_le_bytes()); + } + Width::B64 => return Err("immediate must be 32 bit or less".into()), }, RegImmMem::Mem(mem) => todo!(), } @@ -170,18 +181,24 @@ impl Code { } } - fn modrm_regdisp(&mut self, reg: Reg, disp: u32) { - let disp8 = disp < u8::MAX as u32; - let mod_ = if disp8 { 0b01 } else { 0b10 }; + fn modrm_regdisp(&mut self, reg: Reg, disp: i32) { + const I8_MIN: i32 = i8::MIN as i32; + const I8_MAX: i32 = i8::MAX as i32; + let mod_ = match disp { + 0 => 0b00, + I8_MIN..=I8_MAX => 0b01, + _ => 0b10, + }; self.bytes.push(modrm(mod_, 0, reg.base())); if reg.val() == rsp.val() { // SIB self.bytes.push(0x24); } - if disp8 { - self.bytes.push(disp as u8); - } else { - self.bytes.extend(disp.to_le_bytes()); + match mod_ { + 0b00 => (), + 0b01 => self.bytes.push(disp as u8), + 0b10 => self.bytes.extend(disp.to_le_bytes()), + _ => unreachable!(), } } @@ -245,6 +262,6 @@ impl From for RegImmMem { impl From for RegImmMem { fn from(value: i32) -> Self { - Self::Imm(value as u32 as u64) + Self::Imm(value as u64) } } diff --git a/src/arch/x86_64/reg.rs b/src/arch/x86_64/reg.rs index e345396..8906300 100644 --- a/src/arch/x86_64/reg.rs +++ b/src/arch/x86_64/reg.rs @@ -5,33 +5,34 @@ pub struct Reg { width: Width, } -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +#[repr(u8)] pub enum Width { - B64, - B32, - B16, - B8, + B8 = 0, + B16 = 1, + B32 = 2, + B64 = 3, } -def_regs! { Reg; +def_regs! { 0b0000 : rax eax ax al, - 0b0001 : rcx ecx cx cl, + 0b0001 : rcx ecx cx cl !_, 0b0010 : rdx edx dx dl, 0b0011 : rbx ebx bx bl, - 0b0100 : rsp esp sp spl norex=ah, + 0b0100 : rsp esp sp spl norex=ah !_, 0b0101 : rbp ebp bp bpl norex=ch, - 0b0110 : rsi esi si sil norex=dh, + 0b0110 : rsi esi si sil norex=dh !_, 0b0111 : rdi edi di dil norex=bh, 0b1000 : r8 r8d r8w r8b, - 0b1001 : r9 r9d r9w r9b, + 0b1001 : r9 r9d r9w r9b !_, 0b1010 : r10 r10d r10w r10b, 0b1011 : r11 r11d r11w r11b, 0b1100 : r12 r12d r12w r12b, 0b1101 : r13 r13d r13w r13b, 0b1110 : r14 r14d r14w r14b, - 0b1111 : r15 r15d r15w r15b, + 0b1111 : r15 r15d r15w r15b !_, } impl Reg { @@ -49,16 +50,29 @@ impl Reg { self.high } + pub fn val(&self) -> u8 { + self.val + } + pub fn width(&self) -> Width { self.width } + /// if self has 64 bit width, changes width to 32 bit + pub fn lower64(&mut self) { + self.width.lower64() + } + pub fn requires_rex(&self) -> bool { self.gt8() || self.width == Width::B64 || (self.gt4() && self.width == Width::B8 && !self.high) } + pub fn requires_mem_rex(&self) -> bool { + self.gt8() || (self.gt4() && self.width == Width::B8 && !self.high) + } + pub fn incompatible(&self, other: &Reg) -> bool { (self.requires_rex() && other.high) || (self.high && other.requires_rex()) } @@ -77,6 +91,13 @@ impl Width { Self::B8 { .. } => u8::MAX as u64, } } + + pub fn lower64(&mut self) { + if matches!(self, Width::B64) { + *self = Width::B32; + } + } + pub const fn bytes(&self) -> usize { match self { Self::B64 => 8, @@ -85,29 +106,75 @@ impl Width { Self::B8 { .. } => 1, } } + + pub const fn fit(val: u64) -> Self { + const B8: u64 = 1 << 8; + const B16: u64 = 1 << 16; + const B32: u64 = 1 << 32; + match val { + ..B8 => Self::B8, + B8..B16 => Self::B16, + B16..B32 => Self::B32, + B32.. => Self::B64, + } + } + + pub const fn fiti(val: u64) -> Self { + match val { + ..0x80 => Self::B8, + 0x80..0x8000 => Self::B16, + 0x8000..0x8000_0000 => Self::B32, + 0x8000_0000.. => Self::B64, + } + } + /// greater than 8 bits pub const fn gt8(&self) -> bool { !matches!(self, Self::B8) } } +macro_rules! filter { + ($($filtered:ident)*; ! $_:tt $($item:ident)*; $($rest:tt)*) => { + filter!($($filtered)* $($item)*; $($rest)*) + }; + ($($filtered:ident)*; $($item:ident)*; $($rest:tt)*) => { + filter!($($filtered)*; $($rest)*) + }; + ($($filtered:ident)*;) => { + [$($filtered, )*] + }; +} +use filter; + macro_rules! def_regs { - ($Struct: ident; $($val:literal : $B64:ident $B32:ident $B16:ident $B8:ident $(norex=$B8H:ident)?,)*) => { + ($($val:literal : $B64:ident $B32:ident $B16:ident $B8:ident $(norex=$B8H:ident)? $(!$imp:tt)?,)*) => { $( #[allow(non_upper_case_globals)] - pub const $B64: $Struct = $Struct::new($val, Width::B64, false); + pub const $B64: Reg = Reg::new($val, Width::B64, false); #[allow(non_upper_case_globals)] - pub const $B32: $Struct = $Struct::new($val, Width::B32, false); + pub const $B32: Reg = Reg::new($val, Width::B32, false); #[allow(non_upper_case_globals)] - pub const $B16: $Struct = $Struct::new($val, Width::B16, false); + pub const $B16: Reg = Reg::new($val, Width::B16, false); #[allow(non_upper_case_globals)] - pub const $B8 : $Struct = $Struct::new($val, Width::B8 , false); + pub const $B8 : Reg = Reg::new($val, Width::B8 , false); $( #[allow(non_upper_case_globals)] - pub const $B8H: $Struct = $Struct::new($val, Width::B8, true); + pub const $B8H: Reg = Reg::new($val, Width::B8, true); )? )* - impl $Struct { + + impl Reg { + #[cfg(test)] + pub const ALL: &[Reg] = &[ + $( $B64, $B32, $B16, $B8, $($B8H,)? )* + ]; + + #[cfg(test)] + pub const IMPORTANT: &[Reg] = & + filter!(; $($(!$imp)? $B64 $B32 $B16 $B8 $($B8H)?; )* ) + ; + pub fn parse(s: &str) -> Option { Some(match s.to_lowercase().as_str() { $( @@ -123,6 +190,23 @@ macro_rules! def_regs { }) } } + impl std::fmt::Display for Reg { + #[allow(non_upper_case_globals)] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", match *self { + $( + $B64 => stringify!($B64), + $B32 => stringify!($B32), + $B16 => stringify!($B16), + $B8 => stringify!($B8), + $( + $B8H => stringify!($B8H), + )? + )* + _ => "UNKNOWN", + }) + } + } }; } diff --git a/src/arch/x86_64/test/bin.rs b/src/arch/x86_64/test/bin.rs index 72141ce..3606303 100644 --- a/src/arch/x86_64/test/bin.rs +++ b/src/arch/x86_64/test/bin.rs @@ -94,7 +94,7 @@ fn windows() -> Result<(), CompilerMsg> { c.lea(rdx, text_sym); c.mov(r8d, text.len() as u64)?; c.lea(r9, written); - c.mov(mem(rsp, 0x20), 0)?; + c.mov(mem(rsp, 0x20, Width::B32), 0)?; c.call_mem(write_file); // exit c.mov(ecx, 39)?; diff --git a/src/arch/x86_64/test/mod.rs b/src/arch/x86_64/test/mod.rs index 968b0ac..89572ae 100644 --- a/src/arch/x86_64/test/mod.rs +++ b/src/arch/x86_64/test/mod.rs @@ -1,5 +1,7 @@ pub mod bin; #[cfg(test)] +mod nasm; +#[cfg(test)] mod reg; #[cfg(test)] use super::*; diff --git a/src/arch/x86_64/test/nasm.rs b/src/arch/x86_64/test/nasm.rs new file mode 100644 index 0000000..fde3578 --- /dev/null +++ b/src/arch/x86_64/test/nasm.rs @@ -0,0 +1,127 @@ +use crate::arch::x86_64::*; +use std::{fs::OpenOptions, io::Write, process::Command}; + +const DISPS: &[i32] = &[ + 0x0, + 0x1, + i8::MIN as i32, + i8::MAX as i32, + i16::MIN as i32, + i16::MAX as i32, + i32::MAX, +]; + +const IMMS: &[u64] = &[ + 0x0, + 0x1, + u8::MAX as u64, + u8::MAX as u64 + 1, + u16::MAX as u64, + u16::MAX as u64 + 1, + u32::MAX as u64, + u32::MAX as u64 + 1, + // nasm likes to think u64::MAX is -1i32 for some reason + i64::MAX as u64, +]; + +#[test] +fn mov() { + for ® in Reg::IMPORTANT { + for &disp in DISPS { + for &imm in IMMS { + let width = Width::fit(imm); + let size = match width { + Width::B8 => "BYTE", + Width::B16 => "WORD", + Width::B32 => "DWORD", + Width::B64 => "QWORD", + }; + let ddisp = (disp as i64).abs(); + let sign = if disp < 0 { '-' } else { '+' }; + eq!( + format!("mov {size} [{reg}{sign}0x{ddisp:x}], 0x{imm:x}"), + mov(mem(reg, disp, width), imm) + ); + } + } + } + + for &r1 in Reg::IMPORTANT { + for &r2 in Reg::IMPORTANT { + eq!(format!("mov {r1}, {r2}"), mov(r1, r2)); + } + } + + for &r1 in Reg::IMPORTANT { + for &imm in IMMS { + eq!(format!("mov {r1}, 0x{imm:x}"), mov(r1, imm)); + } + } +} + +macro_rules! eq { + ($asm:expr, $instr:ident $args:tt $(,)?) => { + let asm = $asm; + let expected = nasm(asm.as_ref()); + let mut code = Code::default(); + let res = code.$instr $args; + match (expected, res) { + (Ok(_), Err(e)) => { + panic!("{asm}: failed to compile: {}", e.msg); + } + (Err(e), Ok(_)) => { + let res = &code.bytes[..]; + panic!("{asm}: should not have compiled:\n{e}\ngot: {res:x?}"); + } + (Err(_), Err(_)) => (), + (Ok(expected), Ok(_)) => { + let res = &code.bytes[..]; + if expected != res { + panic!("{asm}: expected {expected:x?}, got {res:x?}") + } + } + } + }; +} +use eq; + +fn nasm(input: &str) -> Result, String> { + let fin = "/tmp/69420nasm_in.asm"; + let fout = "/tmp/69420nasm_out.o"; + let input = "result:".to_string() + input; + write(fin, input.as_bytes()); + run(["nasm", "-w+error", "-felf64", fin, &format!("-o{fout}")])?; + let output = run(["objdump", "--no-addresses", "-dw", "-Mintel", fout])?; + let mut iter = output.lines().skip_while(|l| !l.contains("result")).skip(1); + let res_line = iter.next().unwrap().trim(); + let end = res_line.find("\t").unwrap(); + let res_line = &res_line[..end]; + let bytes = res_line + .trim() + .split(" ") + .map(|s| u8::from_str_radix(s, 16).unwrap()) + .collect(); + Ok(bytes) +} + +fn run(input: [&str; N]) -> Result { + let path = input[0]; + let mut cmd = Command::new(path); + cmd.args(&input[1..]); + let output = cmd.output().expect("failed to run"); + if output.status.code().unwrap() != 0 { + return Err(output.stderr.try_into().unwrap()); + } + Ok(output.stdout.try_into().unwrap()) +} + +fn write(path: &str, binary: &[u8]) { + let mut file = OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(path) + .expect("Failed to create file"); + file.write_all(binary).expect("Failed to write to file"); + file.sync_all().expect("Failed to sync file"); +} diff --git a/src/arch/x86_64/test/reg.rs b/src/arch/x86_64/test/reg.rs index abbaae5..3a8ac1e 100644 --- a/src/arch/x86_64/test/reg.rs +++ b/src/arch/x86_64/test/reg.rs @@ -13,9 +13,10 @@ macro_rules! eq { }; } +// used objdump on some nasm compiled assembly + #[test] fn mov_reg_reg() { - // used objdump on some nasm compiled assembly eq!([0x48, 0x89, 0xd8], mov(rax, rbx)); eq!([0x89, 0xd8], mov(eax, ebx)); eq!([0x66, 0x89, 0xd8], mov(ax, bx)); diff --git a/src/arch/x86_64/util.rs b/src/arch/x86_64/util.rs index aaf2da3..5a14e59 100644 --- a/src/arch/x86_64/util.rs +++ b/src/arch/x86_64/util.rs @@ -22,7 +22,7 @@ pub fn rex(w: impl RexBit, r: impl RexBit, x: impl RexBit, b: impl RexBit) -> u8 #[inline(always)] pub fn bit(val: impl RexBit, pos: u8) -> u8 { - (val.val() as u8) << pos + (val.rex() as u8) << pos } /// assumes the next instruction is directly after @@ -33,29 +33,29 @@ pub fn addr_offset(pos: usize, addr: u64) -> [u8; 4] { } pub trait RexBit { - fn val(self) -> bool; + fn rex(self) -> bool; } impl RexBit for u8 { - fn val(self) -> bool { + fn rex(self) -> bool { self != 0 } } impl RexBit for bool { - fn val(self) -> bool { + fn rex(self) -> bool { self } } impl> RexBit for R { - fn val(self) -> bool { + fn rex(self) -> bool { self.into().gt8() } } impl RexBit for Width { - fn val(self) -> bool { + fn rex(self) -> bool { self == Width::B64 } }