use openworm::net::{ ClientMsgInst, RecvHandler, RequestId, SERVER_NAME, SendResult, ServerMsg, ServerMsgInst, recv_uni, send_uni, }; use quinn::{ Connection, Endpoint, ServerConfig, rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, }; use std::{fs, path::Path}; use std::{ net::{Ipv6Addr, SocketAddr, SocketAddrV6}, sync::Arc, }; use tokio::task::JoinHandle; use tracing::Instrument; pub const SERVER_HOST: Ipv6Addr = Ipv6Addr::UNSPECIFIED; pub fn init_endpoint(port: u16, data_path: &Path) -> Endpoint { let cert_path = data_path.join("cert.der"); let key_path = data_path.join("key.der"); let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) { Ok((cert, key)) => ( CertificateDer::from(cert), PrivateKeyDer::try_from(key).unwrap(), ), Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => { let cert = rcgen::generate_simple_self_signed([SERVER_NAME.into()]).unwrap(); let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der()); let cert = cert.cert.into(); fs::create_dir_all(data_path).expect("failed to create certificate directory"); fs::write(&cert_path, &cert).expect("failed to write certificate"); fs::write(&key_path, key.secret_pkcs8_der()).expect("failed to write private key"); (cert, key.into()) } Err(e) => { panic!("failed to read certificate: {}", e); } }; print!("cert hex: "); for x in cert.iter() { print!("{:02x}", x); } println!(); let server_config = ServerConfig::with_single_cert(vec![cert], key).unwrap(); let server_socket: SocketAddr = SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, port, 0, 0)); quinn::Endpoint::server(server_config, server_socket).unwrap() } #[derive(Clone)] pub struct ClientSender { conn: Connection, } impl ClientSender { pub fn remote(&self) -> SocketAddr { self.conn.remote_address() } pub fn replier(&self, id: RequestId) -> ClientReplier { ClientReplier { conn: self.conn.clone(), req_id: id, } } pub async fn send(&self, msg: impl Into) -> SendResult { let msg = ServerMsgInst { id: None, msg: msg.into(), }; send_uni(&self.conn, msg).await } } pub struct ClientReplier { conn: Connection, req_id: RequestId, } impl ClientReplier { pub async fn send(&self, msg: impl Into) { let msg = ServerMsgInst { id: Some(self.req_id), msg: msg.into(), }; let _ = send_uni(&self.conn, msg).await; } } pub trait ConAccepter: Send + Sync + 'static { fn accept( &self, send: ClientSender, ) -> impl Future> + Send; } pub fn listen( port: u16, data_path: &Path, accepter: impl ConAccepter, ) -> (Endpoint, JoinHandle<()>) { let accepter = Arc::new(accepter); let endpoint = init_endpoint(port, data_path); let res = endpoint.clone(); let handle = tokio::spawn(async move { println!("listening on {}", endpoint.local_addr().unwrap()); let mut tasks = Vec::new(); while let Some(conn) = endpoint.accept().await { let fut = handle_connection(conn, accepter.clone()); tasks.push(tokio::spawn(async move { if let Err(e) = fut.await { eprintln!("connection failed: {reason}", reason = e) } })); } for task in tasks { let _ = task.await; } }); (res, handle) } async fn handle_connection( conn: quinn::Incoming, accepter: Arc, ) -> std::io::Result<()> { let conn = conn.await?; let handler = Arc::new(accepter.accept(ClientSender { conn: conn.clone() }).await); handler.connect().await; let span = tracing::info_span!( "connection", remote = %conn.remote_address(), protocol = %conn .handshake_data() .unwrap() .downcast::().unwrap() .protocol .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) ); async { let res = recv_uni(conn, handler.clone()).await; handler.disconnect(res).await; } .instrument(span) .await; Ok(()) }