diff --git a/src/client/app.rs b/src/client/app.rs index 94ac4eb..2a5f8b5 100644 --- a/src/client/app.rs +++ b/src/client/app.rs @@ -62,7 +62,7 @@ impl ApplicationHandler for App { impl AppHandle { pub fn send(&self, event: ClientEvent) { - self.proxy.send_event(event).unwrap(); + self.proxy.send_event(event).unwrap_or_else(|_| panic!()); self.window.request_redraw(); } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 79f2e1b..e914565 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -6,8 +6,8 @@ use render::Renderer; use winit::{event::WindowEvent, event_loop::ActiveEventLoop}; use crate::{ - client::ui::{Submit, main_view}, - net::client::NetSender, + client::ui::{Submit, main_view, msg_widget}, + net::{ClientMsg, ServerMsg, client::NetSender}, }; mod app; @@ -17,9 +17,9 @@ mod ui; pub use app::AppHandle; -#[derive(Debug)] pub enum ClientEvent { - Connect(NetSender), + Connect { send: NetSender, username: String }, + ServerMsg(ServerMsg), } pub struct Client { @@ -27,6 +27,8 @@ pub struct Client { input: Input, ui: Ui, focus: Option>, + channel: Option>, + username: String, clipboard: Clipboard, } @@ -40,16 +42,36 @@ impl Client { renderer, input: Input::default(), ui, + channel: None, focus: None, + username: "".to_string(), clipboard: Clipboard::new().unwrap(), } } pub fn event(&mut self, event: ClientEvent, _: &ActiveEventLoop) { match event { - ClientEvent::Connect(send) => { - main_view(&mut self.ui, send).set_root(&mut self.ui); + ClientEvent::Connect { send, username } => { + self.username = username; + send.send(ClientMsg::RequestMsgs); + main_view(self, send).set_root(&mut self.ui); } + ClientEvent::ServerMsg(msg) => match msg { + ServerMsg::SendMsg(msg) => { + if let Some(msg_area) = &self.channel { + let msg = msg_widget(msg).add(&mut self.ui); + self.ui[msg_area].children.push(msg.any()); + } + } + ServerMsg::LoadMsgs(msgs) => { + if let Some(msg_area) = &self.channel { + for msg in msgs { + let msg = msg_widget(msg).add(&mut self.ui); + self.ui[msg_area].children.push(msg.any()); + } + } + } + }, } } diff --git a/src/client/ui.rs b/src/client/ui.rs index 65640eb..8aaa7b7 100644 --- a/src/client/ui.rs +++ b/src/client/ui.rs @@ -4,8 +4,8 @@ use len_fns::*; use crate::{ client::{Client, app::AppHandle}, net::{ - ClientMsg, - client::{NetSender, connect}, + ClientMsg, Msg, + client::{ConnectInfo, NetSender, connect}, }, }; @@ -22,17 +22,20 @@ pub fn ui(handle: AppHandle) -> Ui { ui } -pub fn main_view(ui: &mut Ui, network: NetSender) -> WidgetId { - let msg_panel = msg_panel(ui, network); +pub fn main_view(client: &mut Client, network: NetSender) -> WidgetId { + let msg_panel = msg_panel(client, network); let side_bar = rect(Color::BLACK.brighter(0.05)).width(80); - (side_bar, msg_panel).span(Dir::RIGHT).add(ui).any() + (side_bar, msg_panel) + .span(Dir::RIGHT) + .add(&mut client.ui) + .any() } fn login_screen(ui: &mut Ui, handle: AppHandle) -> WidgetId { let mut field = |name| text(name).editable().size(20).add(ui); - let ip = field("ip"); - // let username = field("username"); + let ip = field("localhost:16839"); + let username = field("username"); // let password = field("password"); let fbx = |field: WidgetId| { @@ -44,19 +47,21 @@ fn login_screen(ui: &mut Ui, handle: AppHandle) -> WidgetId { }; let ip_ = ip.clone(); + let username_ = username.clone(); let color = Color::GREEN; let submit = rect(color) .radius(15) .id_on(CursorSense::click(), move |id, client: &mut Client, _| { client.ui[id].color = color.darker(0.3); let ip = client.ui[&ip_].content(); - connect(handle.clone(), ip); + let username = client.ui[&username_].content(); + connect(handle.clone(), ConnectInfo { ip, username }); }) .height(40); ( text("login").size(30), fbx(ip), - // fbx(username), + fbx(username), // fbx(password), submit, ) @@ -70,8 +75,8 @@ fn login_screen(ui: &mut Ui, handle: AppHandle) -> WidgetId { .any() } -fn msg_widget(content: String) -> impl WidgetLike { - let content = text(content) +pub fn msg_widget(msg: Msg) -> impl WidgetLike { + let content = text(msg.content) .editable() .size(20) .text_align(Align::Left) @@ -80,7 +85,7 @@ fn msg_widget(content: String) -> impl WidgetLike { client.ui.text(id).select(ctx.cursor, ctx.size); client.focus = Some(id.clone()); }); - let header = text("some user").size(20); + let header = text(msg.user).size(20); ( image(include_bytes!("./assets/sungals.png")) .sized((70, 70)) @@ -100,14 +105,12 @@ pub fn focus(id: WidgetId) -> impl Fn(&mut Client, CursorData) { } } -pub fn msg_panel(ui: &mut Ui, network: NetSender) -> impl WidgetFn + use<> { +pub fn msg_panel(client: &mut Client, network: NetSender) -> impl WidgetFn + use<> { + let Client { ui, channel, .. } = client; let msg_area = Span::empty(Dir::DOWN).gap(15).add(ui); + *channel = Some(msg_area.clone()); - let send_text = text("some stuff idk") - .editable() - .size(20) - .text_align(Align::Left) - .add(ui); + let send_text = text("").editable().size(20).text_align(Align::Left).add(ui); ( msg_area @@ -120,12 +123,12 @@ pub fn msg_panel(ui: &mut Ui, network: NetSender) -> impl WidgetFn + use< .clone() .id_on(Submit, move |id, client: &mut Client, _| { let content = client.ui.text(id).take(); - network - .send(ClientMsg::SendMsg { - content: content.clone(), - }) - .unwrap(); - let msg = msg_widget(content).add(&mut client.ui); + let msg = Msg { + content: content.clone(), + user: client.username.clone(), + }; + network.send(ClientMsg::SendMsg(msg.clone())); + let msg = msg_widget(msg).add(&mut client.ui); client.ui[&msg_area].children.push(msg.any()); }) .pad(15) diff --git a/src/net/client.rs b/src/net/client.rs index 067054f..b0b0013 100644 --- a/src/net/client.rs +++ b/src/net/client.rs @@ -1,24 +1,42 @@ use crate::{ client::{AppHandle, ClientEvent}, - net::{BINCODE_CONFIG, ClientMsg, SERVER_NAME, no_cert::SkipServerVerification}, + net::{ + ClientMsg, SERVER_NAME, ServerMsg, + no_cert::SkipServerVerification, + transfer::{RecvHandler, recv_uni, send_uni}, + }, }; use quinn::{ClientConfig, Connection, Endpoint, crypto::rustls::QuicClientConfig}; use std::{ net::{Ipv6Addr, SocketAddr, SocketAddrV6, ToSocketAddrs}, sync::Arc, }; -use tokio::{io::AsyncWriteExt, sync::mpsc::UnboundedSender}; +use tokio::sync::mpsc::UnboundedSender; pub const CLIENT_SOCKET: SocketAddr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)); -pub fn connect(handle: AppHandle, ip: String) { +pub struct ConnectInfo { + pub ip: String, + pub username: String, +} + +pub fn connect(handle: AppHandle, info: ConnectInfo) { std::thread::spawn(|| { - connect_the(handle, ip).unwrap(); + connect_the(handle, info).unwrap(); }); } -pub type NetSender = UnboundedSender; +type MsgPayload = ClientMsg; +pub struct NetSender { + send: UnboundedSender, +} + +impl NetSender { + pub fn send(&self, msg: ClientMsg) { + self.send.send(msg).unwrap(); + } +} // async fn connection_cert(addr: SocketAddr) -> anyhow::Result { // let dirs = directories_next::ProjectDirs::from("", "", "openworm").unwrap(); @@ -65,29 +83,36 @@ async fn connection_no_cert(addr: SocketAddr) -> anyhow::Result { } #[tokio::main] -async fn connect_the(handle: AppHandle, ip: String) -> anyhow::Result<()> { - let (client_send, mut client_recv) = tokio::sync::mpsc::unbounded_channel::(); +async fn connect_the(handle: AppHandle, info: ConnectInfo) -> anyhow::Result<()> { + let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::(); - handle.send(ClientEvent::Connect(client_send)); - let addr = ip.to_socket_addrs().unwrap().next().unwrap(); + handle.send(ClientEvent::Connect { + username: info.username, + send: NetSender { send }, + }); + let addr = info.ip.to_socket_addrs().unwrap().next().unwrap(); let conn = connection_no_cert(addr).await?; + let conn_ = conn.clone(); - while let Some(msg) = client_recv.recv().await { - let bytes = bincode::encode_to_vec(msg, BINCODE_CONFIG).unwrap(); - let (mut send, recv) = conn - .open_bi() - .await - .map_err(|e| anyhow::anyhow!("failed to open stream: {e}"))?; + let recv = ServerRecv { handle }; + tokio::spawn(recv_uni(conn_, recv)); - drop(recv); - - send.write_u64(bytes.len() as u64) - .await - .expect("failed to send"); - send.write_all(&bytes).await.expect("failed to send"); - send.finish().unwrap(); - send.stopped().await.unwrap(); + while let Some(msg) = ui_recv.recv().await { + if send_uni(&conn, msg).await.is_err() { + println!("disconnected from server"); + break; + } } Ok(()) } + +struct ServerRecv { + handle: AppHandle, +} + +impl RecvHandler for ServerRecv { + async fn msg(&self, msg: ServerMsg) { + self.handle.send(ClientEvent::ServerMsg(msg)); + } +} diff --git a/src/net/mod.rs b/src/net/mod.rs index 8beacd9..3032a8d 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -3,16 +3,27 @@ use bincode::config::Configuration; pub mod client; mod no_cert; pub mod server; +pub mod transfer; pub const SERVER_NAME: &str = "openworm"; pub const BINCODE_CONFIG: Configuration = bincode::config::standard(); #[derive(Debug, bincode::Encode, bincode::Decode)] pub enum ClientMsg { - SendMsg { content: String }, + SendMsg(Msg), + RequestMsgs, } #[derive(Debug, bincode::Encode, bincode::Decode)] pub enum ServerMsg { - RecvMsg { content: String }, + SendMsg(Msg), + LoadMsgs(Vec), +} + +pub type ServerResp = Result; + +#[derive(Debug, Clone, bincode::Encode, bincode::Decode)] +pub struct Msg { + pub content: String, + pub user: String, } diff --git a/src/net/server.rs b/src/net/server.rs index 80545b8..4079c16 100644 --- a/src/net/server.rs +++ b/src/net/server.rs @@ -1,6 +1,9 @@ -use crate::net::{BINCODE_CONFIG, ClientMsg, SERVER_NAME}; +use crate::net::{ + ClientMsg, SERVER_NAME, ServerMsg, + transfer::{RecvHandler, SendResult, recv_uni, send_uni}, +}; use quinn::{ - Endpoint, ServerConfig, + Connection, Endpoint, ServerConfig, rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, }; use std::{fs, path::Path}; @@ -8,7 +11,6 @@ use std::{ net::{Ipv6Addr, SocketAddr, SocketAddrV6}, sync::Arc, }; -use tokio::io::AsyncReadExt; use tracing::Instrument; pub const DEFAULT_PORT: u16 = 16839; @@ -53,17 +55,31 @@ pub fn init_endpoint(data_path: &Path) -> Endpoint { quinn::Endpoint::server(server_config, SERVER_SOCKET).unwrap() } -pub trait ConHandler: Send + Sync + 'static { - fn on_msg(&self, msg: ClientMsg); +#[derive(Clone)] +pub struct ClientSender { + conn: Connection, } -pub async fn listen(data_path: &Path, handler: impl ConHandler) { - let handler = Arc::new(handler); +impl ClientSender { + pub async fn send(&self, msg: ServerMsg) -> SendResult { + send_uni(&self.conn, msg).await + } +} + +pub trait ConAccepter: Send + Sync + 'static { + fn accept( + &self, + send: ClientSender, + ) -> impl Future> + Send; +} + +pub async fn listen(data_path: &Path, accepter: impl ConAccepter) { + let accepter = Arc::new(accepter); let endpoint = init_endpoint(data_path); println!("listening on {}", endpoint.local_addr().unwrap()); while let Some(conn) = endpoint.accept().await { - let fut = handle_connection(conn, handler.clone()); + let fut = handle_connection(conn, accepter.clone()); tokio::spawn(async move { if let Err(e) = fut.await { eprintln!("connection failed: {reason}", reason = e) @@ -74,13 +90,14 @@ pub async fn listen(data_path: &Path, handler: impl ConHandler) { async fn handle_connection( conn: quinn::Incoming, - handler: Arc, + accepter: Arc, ) -> std::io::Result<()> { - let connection = conn.await?; + let conn = conn.await?; + let handler = accepter.accept(ClientSender { conn: conn.clone() }).await; let span = tracing::info_span!( "connection", - remote = %connection.remote_address(), - protocol = %connection + remote = %conn.remote_address(), + protocol = %conn .handshake_data() .unwrap() .downcast::().unwrap() @@ -88,47 +105,9 @@ async fn handle_connection( .map_or_else(|| "".into(), |x| String::from_utf8_lossy(&x).into_owned()) ); async { - // Each stream initiated by the client constitutes a new request. - loop { - let stream = connection.accept_bi().await; - // let time = Instant::now(); - let stream = match stream { - Err(quinn::ConnectionError::ApplicationClosed { .. }) => { - println!("connection closed"); - return Ok(()); - } - Err(e) => { - return Err(e); - } - Ok(s) => s, - }; - let handler = handler.clone(); - tokio::spawn( - async move { - if let Err(e) = handle_stream(stream, handler).await { - eprintln!("failed: {reason}", reason = e); - } - } - .instrument(tracing::info_span!("request")), - ); - } + recv_uni(conn, handler).await; } .instrument(span) - .await?; - Ok(()) -} - -async fn handle_stream( - (send, mut recv): (quinn::SendStream, quinn::RecvStream), - handler: Arc, -) -> Result<(), String> { - drop(send); - let len = recv.read_u64().await.unwrap(); - let bytes = recv - .read_to_end(len as usize) - .await - .map_err(|e| format!("failed reading request: {}", e))?; - let (msg, _) = bincode::decode_from_slice::(&bytes, BINCODE_CONFIG).unwrap(); - handler.on_msg(msg); + .await; Ok(()) } diff --git a/src/net/transfer.rs b/src/net/transfer.rs new file mode 100644 index 0000000..09e7afe --- /dev/null +++ b/src/net/transfer.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; + +use crate::net::BINCODE_CONFIG; +use quinn::Connection; +use tokio::io::{AsyncReadExt as _, AsyncWriteExt}; +use tracing::Instrument as _; + +pub trait RecvHandler: Send + Sync + 'static { + fn msg(&self, msg: M) -> impl Future + Send; +} + +pub type SendResult = Result<(), ()>; +pub async fn send_uni(conn: &Connection, msg: M) -> SendResult { + let bytes = bincode::encode_to_vec(msg, BINCODE_CONFIG).unwrap(); + let mut send = conn.open_uni().await.map_err(|_| ())?; + + send.write_u64(bytes.len() as u64).await.map_err(|_| ())?; + send.write_all(&bytes).await.map_err(|_| ())?; + send.finish().map_err(|_| ())?; + send.stopped().await.map_err(|_| ())?; + Ok(()) +} + +pub async fn recv_uni>(conn: Connection, handler: impl RecvHandler) { + let handler = Arc::new(handler); + loop { + let mut recv = match conn.accept_uni().await { + Err(quinn::ConnectionError::ApplicationClosed { .. }) => { + return; + } + Err(e) => { + eprintln!("connection error: {e}"); + return; + } + Ok(s) => s, + }; + let handler = handler.clone(); + tokio::spawn( + async move { + let len = recv.read_u64().await.unwrap(); + let bytes = recv.read_to_end(len as usize).await.unwrap(); + let (msg, _) = bincode::decode_from_slice::(&bytes, BINCODE_CONFIG).unwrap(); + handler.msg(msg).await; + } + .instrument(tracing::info_span!("request")), + ); + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 2767c82..f01873e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,20 +1,82 @@ use crate::net::{ - ClientMsg, - server::{ConHandler, listen}, + ClientMsg, Msg, ServerMsg, + server::{ClientSender, ConAccepter, listen}, + transfer::RecvHandler, }; +use std::{ + collections::HashMap, + sync::{ + Arc, + atomic::{AtomicU64, Ordering}, + }, +}; +use tokio::sync::RwLock; #[tokio::main] pub async fn run_server() { let dirs = directories_next::ProjectDirs::from("", "", "openworm").unwrap(); let path = dirs.data_local_dir(); - let handler = ClientHandler {}; + let handler = ServerListener { + msgs: Default::default(), + senders: Default::default(), + count: 0.into(), + }; listen(path, handler).await; } -pub struct ClientHandler {} +type ClientId = u64; -impl ConHandler for ClientHandler { - fn on_msg(&self, msg: ClientMsg) { - println!("received msg: {msg:?}"); +struct ServerListener { + msgs: Arc>>, + senders: Arc>>, + count: AtomicU64, +} + +impl ConAccepter for ServerListener { + async fn accept(&self, send: ClientSender) -> impl RecvHandler { + let id = self.count.fetch_add(1, Ordering::Release); + self.senders.write().await.insert(id, send.clone()); + ClientHandler { + msgs: self.msgs.clone(), + senders: self.senders.clone(), + send, + id, + } + } +} + +struct ClientHandler { + msgs: Arc>>, + send: ClientSender, + senders: Arc>>, + id: ClientId, +} + +impl RecvHandler for ClientHandler { + async fn msg(&self, msg: ClientMsg) { + match msg { + ClientMsg::SendMsg(msg) => { + self.msgs.write().await.push(msg.clone()); + let mut handles = Vec::new(); + for (&id, send) in self.senders.read().await.iter() { + if id == self.id { + continue; + } + let send = send.clone(); + let msg = msg.clone(); + let fut = async move { + let _ = send.send(ServerMsg::SendMsg(msg)).await; + }; + handles.push(tokio::spawn(fut)); + } + for h in handles { + h.await.unwrap(); + } + } + ClientMsg::RequestMsgs => { + let msgs = self.msgs.read().await.clone(); + let _ = self.send.send(ServerMsg::LoadMsgs(msgs)).await; + } + } } }