use crate::ClientEvent; use dashmap::DashMap; use openworm::net::{ ClientMsg, ClientRequestMsg, CreateAccount, CreateAccountResp, RecvHandler, RequestId, SERVER_NAME, ServerMsg, ServerRespMsg, SkipServerVerification, recv_uni, send_uni, }; use quinn::{ ClientConfig, Connection, Endpoint, IdleTimeout, TransportConfig, crypto::rustls::QuicClientConfig, rustls::pki_types::CertificateDer, }; use std::{ net::{Ipv6Addr, SocketAddr, SocketAddrV6, ToSocketAddrs}, sync::Arc, time::Duration, }; use tokio::sync::{mpsc::UnboundedSender, oneshot}; use winit::event_loop::EventLoopProxy; pub const CLIENT_SOCKET: SocketAddr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)); pub struct ConnectInfo { pub url: String, pub cert: Vec, } pub struct NetHandle { send: UnboundedSender, } #[derive(Clone)] pub struct AppHandle { pub proxy: EventLoopProxy, } impl AppHandle { pub fn send(&self, event: ClientEvent) { self.proxy.send_event(event).unwrap_or_else(|_| panic!()); } } type NetResult = Result; pub trait ClientRequest {} pub enum NetCtrlMsg { Send(ClientMsg), Request(ClientMsg, oneshot::Sender), Exit, } impl NetHandle { fn send_(&self, msg: NetCtrlMsg) { let _ = self.send.send(msg); } pub fn send(&self, msg: impl Into) { self.send_(NetCtrlMsg::Send(msg.into())); } pub async fn request(&self, msg: R) -> Result { let (send, recv) = oneshot::channel(); self.send_(NetCtrlMsg::Request(msg.into(), send)); let Ok(recv) = recv.await else { return Err(()) }; if let Some(res) = R::result(recv) { Ok(res) } else { Err(()) } } pub fn exit(self) { self.send_(NetCtrlMsg::Exit); } } pub trait RequestMsg: Into { type Result; fn result(msg: ServerMsg) -> Option; } impl RequestMsg for CreateAccount { type Result = CreateAccountResp; fn result(msg: ServerMsg) -> Option { if let ServerMsg::CreateAccount(res) = msg { Some(res) } else { None } } } async fn connection_cert(addr: SocketAddr, cert: CertificateDer) -> NetResult { let mut roots = quinn::rustls::RootCertStore::empty(); roots.add(cert); let client_crypto = quinn::rustls::ClientConfig::builder() .with_root_certificates(roots) .with_no_client_auth(); let client_config = ClientConfig::new(Arc::new(QuicClientConfig::try_from(client_crypto)?)); let mut endpoint = quinn::Endpoint::client(SocketAddr::from_str("[::]:0").unwrap())?; endpoint.set_default_client_config(client_config); endpoint .connect(addr, SERVER_NAME)? .await .map_err(|e| format!("failed to connect: {}", e)) } async fn connection_no_cert(addr: SocketAddr) -> NetResult<(Endpoint, Connection)> { let mut endpoint = Endpoint::client(CLIENT_SOCKET).map_err(|e| e.to_string())?; let quic = QuicClientConfig::try_from( quinn::rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(SkipServerVerification::new()) .with_no_client_auth(), ) .map_err(|e| e.to_string())?; let mut config = ClientConfig::new(Arc::new(quic)); let mut transport = TransportConfig::default(); transport.keep_alive_interval(Some(Duration::from_secs(5))); transport.max_idle_timeout(Some( IdleTimeout::try_from(Duration::from_secs(10)).unwrap(), )); config.transport_config(transport.into()); endpoint.set_default_client_config(config); // connect to server let con = endpoint .connect(addr, SERVER_NAME) .map_err(|e| e.to_string())? .await .map_err(|e| e.to_string())?; Ok((endpoint, con)) } impl NetHandle { pub async fn connect(msg: impl MsgHandler, info: ConnectInfo) -> Result { let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::(); let cert = CertificateDer::from_slice(&info.cert); let addr = info .url .to_socket_addrs() .map_err(|e| e.to_string())? .next() .ok_or("no addresses found".to_string())?; let (endpoint, conn) = connection_cert(addr).await?; let conn_ = conn.clone(); let mut req_id = RequestId::first(); let recv = Arc::new(ServerRecv { msg, requests: DashMap::default(), }); tokio::spawn(recv_uni(conn_, recv.clone())); tokio::spawn(async move { while let Some(msg) = ui_recv.recv().await { let request_id = req_id.next(); match msg { NetCtrlMsg::Send(msg) => { let msg = ClientRequestMsg { id: request_id, msg: msg.into(), }; if send_uni(&conn, msg).await.is_err() { println!("disconnected from server"); break; } } NetCtrlMsg::Request(msg, send) => { let msg = ClientRequestMsg { id: request_id, msg: msg.into(), }; recv.requests.insert(request_id, send); if send_uni(&conn, msg).await.is_err() { println!("disconnected from server"); break; } } NetCtrlMsg::Exit => { conn.close(0u32.into(), &[]); endpoint.wait_idle().await; break; } } } }); Ok(NetHandle { send }) } } pub trait MsgHandler: Sync + Send + 'static { fn run(&self, msg: ServerMsg) -> impl Future + Send; } impl MsgHandler for F where for<'a> F::CallRefFuture<'a>: Send, { async fn run(&self, msg: ServerMsg) { self(msg).await; } } struct ServerRecv { requests: DashMap>, msg: F, } impl RecvHandler for ServerRecv { async fn msg(&self, resp: ServerRespMsg) { let msg = resp.msg.into(); if let Some(id) = resp.id && let Some((_, send)) = self.requests.remove(&id) { send.send(msg); } else { self.msg.run(msg).await; } } }