use crate::ClientSender; use dashmap::DashMap; use openworm::net::{ ClientMsg, ClientMsgInst, RecvHandler, RequestId, RequestMsg, SERVER_NAME, ServerMsg, ServerMsgInst, SkipServerVerification, recv_uni, send_uni, }; use quinn::{ ClientConfig, Connection, Endpoint, IdleTimeout, TransportConfig, crypto::rustls::QuicClientConfig, rustls::pki_types::CertificateDer, }; use std::{ net::{Ipv6Addr, SocketAddr, SocketAddrV6, ToSocketAddrs}, sync::Arc, time::Duration, }; use tokio::sync::{mpsc::UnboundedSender, oneshot}; pub const CLIENT_SOCKET: SocketAddr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)); pub struct ConnectInfo { pub url: String, pub cert: Vec, } #[derive(Clone)] pub struct NetHandle { send: UnboundedSender, event_sender: ClientSender, } type NetResult = Result; pub enum NetCtrlMsg { Send(ClientMsg), Request(ClientMsg, oneshot::Sender), RequestSync(ClientMsg, Box), Exit, } type Resp = Result; impl NetHandle { fn send_(&self, msg: NetCtrlMsg) { let _ = self.send.send(msg); } pub fn send(&self, msg: impl Into) { self.send_(NetCtrlMsg::Send(msg.into())); } pub async fn request(&self, msg: R) -> Resp { let (send, recv) = oneshot::channel(); self.send_(NetCtrlMsg::Request(msg.into(), send)); let Ok(recv) = recv.await else { return Err(()) }; if let Some(res) = R::result(recv) { Ok(res) } else { Err(()) } } pub fn request_sync(&self, msg: R) -> SyncRecv { let (send, recv) = oneshot::channel(); let sender = self.event_sender.clone(); self.send_(NetCtrlMsg::RequestSync( msg.into(), Box::new(move |msg| { let _ = send.send(if let Some(res) = R::result(msg) { Ok(res) } else { Err(()) }); sender.run(); }), )); SyncRecv:: { recv } } pub fn exit(self) { self.send_(NetCtrlMsg::Exit); } } pub struct SyncRecv { recv: oneshot::Receiver>, } impl SyncRecv { pub fn try_recv(&mut self) -> Option> { match self.recv.try_recv() { Ok(res) => Some(res), Err(oneshot::error::TryRecvError::Empty) => None, Err(oneshot::error::TryRecvError::Closed) => Some(Err(())), } } } async fn connection_cert( addr: SocketAddr, cert: CertificateDer<'_>, ) -> NetResult<(Endpoint, Connection)> { let mut roots = quinn::rustls::RootCertStore::empty(); roots .add(cert) .map_err(|e| format!("Invalid Certificate: {e:?}"))?; let client_crypto = quinn::rustls::ClientConfig::builder() .with_root_certificates(roots) .with_no_client_auth(); let client_config = ClientConfig::new(Arc::new( QuicClientConfig::try_from(client_crypto).map_err(|e| e.to_string())?, )); let mut endpoint = quinn::Endpoint::client(CLIENT_SOCKET).map_err(|e| e.to_string())?; endpoint.set_default_client_config(client_config); let conn = endpoint .connect(addr, SERVER_NAME) .map_err(|e| e.to_string())? .await .map_err(|e| format!("failed to connect: {}", e))?; Ok((endpoint, conn)) } async fn connection_no_cert(addr: SocketAddr) -> NetResult<(Endpoint, Connection)> { let mut endpoint = Endpoint::client(CLIENT_SOCKET).map_err(|e| e.to_string())?; let quic = QuicClientConfig::try_from( quinn::rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(SkipServerVerification::new()) .with_no_client_auth(), ) .map_err(|e| e.to_string())?; let mut config = ClientConfig::new(Arc::new(quic)); let mut transport = TransportConfig::default(); transport.keep_alive_interval(Some(Duration::from_secs(5))); transport.max_idle_timeout(Some( IdleTimeout::try_from(Duration::from_secs(10)).unwrap(), )); config.transport_config(transport.into()); endpoint.set_default_client_config(config); // connect to server let con = endpoint .connect(addr, SERVER_NAME) .map_err(|e| e.to_string())? .await .map_err(|e| e.to_string())?; Ok((endpoint, con)) } impl NetHandle { pub async fn connect( msg: impl MsgHandler, info: ConnectInfo, event_sender: ClientSender, ) -> Result { let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::(); let cert = CertificateDer::from_slice(&info.cert); let addr = info .url .to_socket_addrs() .map_err(|e| e.to_string())? .next() .ok_or("no addresses found".to_string())?; let (endpoint, conn) = connection_cert(addr, cert).await?; let conn_ = conn.clone(); let mut req_id = RequestId::first(); let recv = Arc::new(ServerRecv { msg, requests_sync: DashMap::default(), requests: DashMap::default(), }); tokio::spawn(recv_uni(conn_, recv.clone())); tokio::spawn(async move { while let Some(msg) = ui_recv.recv().await { let request_id = req_id.next(); match msg { NetCtrlMsg::Send(msg) => { let msg = ClientMsgInst { id: request_id, msg, }; if send_uni(&conn, msg).await.is_err() { println!("disconnected from server"); break; } } NetCtrlMsg::Request(msg, send) => { let msg = ClientMsgInst { id: request_id, msg, }; recv.requests.insert(request_id, send); if send_uni(&conn, msg).await.is_err() { println!("disconnected from server"); break; } } NetCtrlMsg::RequestSync(msg, f) => { let msg = ClientMsgInst { id: request_id, msg, }; recv.requests_sync.insert(request_id, f); if send_uni(&conn, msg).await.is_err() { println!("disconnected from server"); break; } } NetCtrlMsg::Exit => { conn.close(0u32.into(), &[]); endpoint.wait_idle().await; break; } } } }); Ok(NetHandle { send, event_sender }) } } pub trait MsgHandler: Sync + Send + 'static { fn run(&self, msg: ServerMsg) -> impl Future + Send; } impl MsgHandler for F where for<'a> F::CallRefFuture<'a>: Send, { async fn run(&self, msg: ServerMsg) { self(msg).await; } } struct ServerRecv { requests: DashMap>, requests_sync: DashMap>, msg: F, } impl RecvHandler for ServerRecv { async fn msg(&self, resp: ServerMsgInst) { if let Some(id) = resp.id { if let Some((_, send)) = self.requests.remove(&id) { let _ = send.send(resp.msg); } else if let Some((_, f)) = self.requests_sync.remove(&id) { f(resp.msg) } } else { self.msg.run(resp.msg).await; } } }