diff --git a/src/bin/client/net.rs b/src/bin/client/net.rs index d4b4f28..fd9f453 100644 --- a/src/bin/client/net.rs +++ b/src/bin/client/net.rs @@ -6,7 +6,7 @@ use openworm::net::{ }; use quinn::{ ClientConfig, Connection, Endpoint, IdleTimeout, TransportConfig, - crypto::rustls::QuicClientConfig, + crypto::rustls::QuicClientConfig, rustls::pki_types::CertificateDer, }; use std::{ net::{Ipv6Addr, SocketAddr, SocketAddrV6, ToSocketAddrs}, @@ -21,6 +21,7 @@ pub const CLIENT_SOCKET: SocketAddr = pub struct ConnectInfo { pub url: String, + pub cert: Vec, } pub struct NetHandle { @@ -90,31 +91,20 @@ impl RequestMsg for CreateAccount { } } -// async fn connection_cert(addr: SocketAddr) -> NetResult { -// let dirs = directories_next::ProjectDirs::from("", "", "openworm").unwrap(); -// let mut roots = quinn::rustls::RootCertStore::empty(); -// match fs::read(dirs.data_local_dir().join("cert.der")) { -// Ok(cert) => { -// roots.add(CertificateDer::from(cert))?; -// } -// Err(ref e) if e.kind() == ErrorKind::NotFound => { -// eprintln!("local server certificate not found"); -// } -// Err(e) => { -// eprintln!("failed to open local server certificate: {}", e); -// } -// } -// let client_crypto = quinn::rustls::ClientConfig::builder() -// .with_root_certificates(roots) -// .with_no_client_auth(); -// let client_config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_crypto)?)); -// let mut endpoint = quinn::Endpoint::client(SocketAddr::from_str("[::]:0").unwrap())?; -// endpoint.set_default_client_config(client_config); -// endpoint -// .connect(addr, SERVER_NAME)? -// .await -// .map_err(|e| format!("failed to connect: {}", e)) -// } +async fn connection_cert(addr: SocketAddr, cert: CertificateDer) -> NetResult { + let mut roots = quinn::rustls::RootCertStore::empty(); + roots.add(cert); + let client_crypto = quinn::rustls::ClientConfig::builder() + .with_root_certificates(roots) + .with_no_client_auth(); + let client_config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_crypto)?)); + let mut endpoint = quinn::Endpoint::client(SocketAddr::from_str("[::]:0").unwrap())?; + endpoint.set_default_client_config(client_config); + endpoint + .connect(addr, SERVER_NAME)? + .await + .map_err(|e| format!("failed to connect: {}", e)) +} async fn connection_no_cert(addr: SocketAddr) -> NetResult<(Endpoint, Connection)> { let mut endpoint = Endpoint::client(CLIENT_SOCKET).map_err(|e| e.to_string())?; @@ -147,59 +137,62 @@ async fn connection_no_cert(addr: SocketAddr) -> NetResult<(Endpoint, Connection Ok((endpoint, con)) } -pub async fn connect(msg: impl MsgHandler, info: ConnectInfo) -> Result { - let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::(); +impl NetHandle { + pub async fn connect(msg: impl MsgHandler, info: ConnectInfo) -> Result { + let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::(); - let addr = info - .url - .to_socket_addrs() - .map_err(|e| e.to_string())? - .next() - .ok_or("no addresses found".to_string())?; - let (endpoint, conn) = connection_no_cert(addr).await?; - let conn_ = conn.clone(); + let cert = CertificateDer::from_slice(&info.cert); + let addr = info + .url + .to_socket_addrs() + .map_err(|e| e.to_string())? + .next() + .ok_or("no addresses found".to_string())?; + let (endpoint, conn) = connection_cert(addr).await?; + let conn_ = conn.clone(); - let mut req_id = RequestId::first(); - let recv = Arc::new(ServerRecv { - msg, - requests: DashMap::default(), - }); - tokio::spawn(recv_uni(conn_, recv.clone())); - tokio::spawn(async move { - while let Some(msg) = ui_recv.recv().await { - let request_id = req_id.next(); - match msg { - NetCtrlMsg::Send(msg) => { - let msg = ClientRequestMsg { - id: request_id, - msg: msg.into(), - }; - if send_uni(&conn, msg).await.is_err() { - println!("disconnected from server"); + let mut req_id = RequestId::first(); + let recv = Arc::new(ServerRecv { + msg, + requests: DashMap::default(), + }); + tokio::spawn(recv_uni(conn_, recv.clone())); + tokio::spawn(async move { + while let Some(msg) = ui_recv.recv().await { + let request_id = req_id.next(); + match msg { + NetCtrlMsg::Send(msg) => { + let msg = ClientRequestMsg { + id: request_id, + msg: msg.into(), + }; + if send_uni(&conn, msg).await.is_err() { + println!("disconnected from server"); + break; + } + } + NetCtrlMsg::Request(msg, send) => { + let msg = ClientRequestMsg { + id: request_id, + msg: msg.into(), + }; + recv.requests.insert(request_id, send); + if send_uni(&conn, msg).await.is_err() { + println!("disconnected from server"); + break; + } + } + NetCtrlMsg::Exit => { + conn.close(0u32.into(), &[]); + endpoint.wait_idle().await; break; } } - NetCtrlMsg::Request(msg, send) => { - let msg = ClientRequestMsg { - id: request_id, - msg: msg.into(), - }; - recv.requests.insert(request_id, send); - if send_uni(&conn, msg).await.is_err() { - println!("disconnected from server"); - break; - } - } - NetCtrlMsg::Exit => { - conn.close(0u32.into(), &[]); - endpoint.wait_idle().await; - break; - } } - } - }); + }); - Ok(NetHandle { send }) + Ok(NetHandle { send }) + } } pub trait MsgHandler: Sync + Send + 'static { diff --git a/src/bin/client/ui/connect.rs b/src/bin/client/ui/connect.rs index 0367366..18614ba 100644 --- a/src/bin/client/ui/connect.rs +++ b/src/bin/client/ui/connect.rs @@ -1,6 +1,6 @@ use openworm::net::{CreateAccount, CreateAccountResp}; -use crate::net::{self, ConnectInfo}; +use crate::net::{ConnectInfo, NetHandle}; use super::*; @@ -41,17 +41,23 @@ pub fn start(rsc: &mut Rsc) -> WeakWidget { pub fn create_account(rsc: &mut Rsc) -> WeakWidget { let url = field("", "server", rsc); let token = field("", "account creation token", rsc); + let cert = field("", "certificate hex", rsc); let username = field("", "username", rsc); let password = field("", "password", rsc); let create = Button::submit("create", rsc); rsc.events.register(create, Submit, move |ctx, rsc| { - create.disable(rsc); let url = rsc[url].content(); let token = rsc[token].content(); + let cert = rsc[cert].content(); + let Ok(cert) = decode_hex(&cert) else { + rsc[ctx.state.notif].inner = Some(werror("Invalid certificate hex", rsc)); + return; + }; let username = rsc[username].content(); let password = rsc[password].content(); let login_key = ctx.state.data.login_key(&username); + create.disable(rsc); rsc.spawn_task(async move |mut ctx| { let mut fail = move |reason| { ctx.update(move |ctx, rsc| { @@ -59,11 +65,11 @@ pub fn create_account(rsc: &mut Rsc) -> WeakWidget { create.enable(rsc); }) }; - let Ok(net) = net::connect( + let Ok(net) = NetHandle::connect( async |msg| { println!("msg recv :joy:"); }, - ConnectInfo { url }, + ConnectInfo { url, cert }, ) .await else { @@ -98,6 +104,7 @@ pub fn create_account(rsc: &mut Rsc) -> WeakWidget { wtext("Create Account").text_align(Align::CENTER).size(30), field_box(url, rsc), field_box(token, rsc), + field_box(cert, rsc), field_box(username, rsc), field_box(password, rsc), create, @@ -108,6 +115,13 @@ pub fn create_account(rsc: &mut Rsc) -> WeakWidget { .add(rsc) } +pub fn decode_hex(s: &str) -> Result, std::num::ParseIntError> { + (0..s.len()) + .step_by(2) + .map(|i| u8::from_str_radix(&s[i..i + 2], 16)) + .collect() +} + // pub fn connect_screen(client: &mut Client, ui: &mut Ui, state: &UiState) -> WeakWidget { // let Client { data, proxy, .. } = client; // let ip = field_widget(&data.ip, "ip", ui); diff --git a/src/bin/server/main.rs b/src/bin/server/main.rs index 69b65a8..415157a 100644 --- a/src/bin/server/main.rs +++ b/src/bin/server/main.rs @@ -42,8 +42,8 @@ fn main() { #[tokio::main] pub async fn run_server(port: u16) { let dir = DataDir::default(); - let path = dir.get(); - let db = Db::open(path.join("server_db")); + let path = dir.get().join("server"); + let db = Db::open(path.join("db")); let handler = ServerListener { senders: Default::default(), count: 0.into(), @@ -53,7 +53,7 @@ pub async fn run_server(port: u16) { let token = account_token(&db, ServerPerms::ACCOUNT_TOKENS); println!("no users found, token for admin: {token}"); } - let (endpoint, handle) = listen(port, path, handler); + let (endpoint, handle) = listen(port, &path, handler); let _ = ctrl_c().await; println!("stopping server"); println!("closing connections..."); diff --git a/src/bin/server/net.rs b/src/bin/server/net.rs index 9a945ba..3753532 100644 --- a/src/bin/server/net.rs +++ b/src/bin/server/net.rs @@ -37,19 +37,12 @@ pub fn init_endpoint(port: u16, data_path: &Path) -> Endpoint { panic!("failed to read certificate: {}", e); } }; - // let server_crypto = quinn::rustls::ServerConfig::builder() - // .with_no_client_auth() - // .with_single_cert(vec![cert], key) - // .unwrap(); - // - // let server_config = quinn::ServerConfig::with_crypto(Arc::new( - // QuicServerConfig::try_from(server_crypto).unwrap(), - // )); - + print!("cert hex: "); + for x in cert.iter() { + print!("{:x}", x); + } + println!(); let server_config = ServerConfig::with_single_cert(vec![cert], key).unwrap(); - // let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); - // transport_config.max_concurrent_uni_streams(0_u8.into()); - let server_socket: SocketAddr = SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, port, 0, 0)); quinn::Endpoint::server(server_config, server_socket).unwrap() }