persistence + proper disconnect

This commit is contained in:
2025-11-28 17:29:33 -05:00
parent 029d62cb53
commit 7557507f27
16 changed files with 413 additions and 67 deletions

View File

@@ -1,8 +1,7 @@
use openworm::net::Msg;
use crate::Client;
pub fn debug(_client: &mut Client) {
pub fn debug(client: &mut Client) {
client.ui.debug_layers();
// let mut file = std::fs::OpenOptions::new()
// .write(true)
// .create(true)
@@ -15,14 +14,14 @@ pub fn debug(_client: &mut Client) {
// openworm::net::BINCODE_CONFIG,
// )
// .unwrap();
let mut file = std::fs::OpenOptions::new()
.read(true)
.open("./old_msgs")
.unwrap();
let msgs: Vec<Msg> =
bincode::decode_from_std_read(&mut file, openworm::net::BINCODE_CONFIG).unwrap();
for msg in msgs {
println!("{msg:?}");
}
// let mut file = std::fs::OpenOptions::new()
// .read(true)
// .open("./old_msgs")
// .unwrap();
// let msgs: Vec<NetMsg> =
// bincode::decode_from_std_read(&mut file, openworm::net::BINCODE_CONFIG).unwrap();
// for msg in msgs {
// println!("{msg:?}");
// }
// client.ui.debug_layers();
}

View File

@@ -2,7 +2,7 @@
use crate::{
app::App,
net::NetSender,
net::{NetCtrlMsg, NetHandle, NetSender, NetState},
rsc::{CLIENT_DATA, ClientData},
ui::*,
};
@@ -11,7 +11,7 @@ use arboard::Clipboard;
use input::Input;
use iris::prelude::*;
use openworm::{
net::{ClientMsg, Msg, ServerMsg, install_crypto_provider},
net::{ClientMsg, NetMsg, ServerMsg, install_crypto_provider},
rsc::DataDir,
};
use render::Renderer;
@@ -53,7 +53,8 @@ pub struct Client {
data: ClientData,
handle: AppHandle,
error: Option<WidgetId<WidgetPtr>>,
msgs: Vec<Msg>,
net: NetState,
msgs: Vec<NetMsg>,
ime: usize,
last_click: Instant,
}
@@ -78,6 +79,7 @@ impl Client {
dir,
channel: None,
focus: None,
net: Default::default(),
username: "<unknown>".to_string(),
clipboard: Clipboard::new().unwrap(),
error: None,
@@ -94,6 +96,13 @@ impl Client {
ClientEvent::Connect { send, username } => {
self.username = username;
send.send(ClientMsg::RequestMsgs);
let NetState::Connecting(th) = self.net.take() else {
panic!("invalid state");
};
self.net = NetState::Connected(NetHandle {
send: send.clone(),
thread: th,
});
main_view(self, send).set_root(&mut self.ui);
}
ClientEvent::ServerMsg(msg) => match msg {
@@ -213,6 +222,10 @@ impl Client {
}
pub fn exit(&mut self) {
if let Some(handle) = self.net.take_connection() {
handle.send.send(NetCtrlMsg::Exit);
let _ = handle.thread.join();
}
self.dir.save(CLIENT_DATA, &self.data);
}
}

View File

@@ -9,6 +9,7 @@ use quinn::{
use std::{
net::{Ipv6Addr, SocketAddr, SocketAddrV6, ToSocketAddrs},
sync::Arc,
thread::JoinHandle,
time::Duration,
};
use tokio::sync::mpsc::UnboundedSender;
@@ -21,24 +22,69 @@ pub struct ConnectInfo {
pub username: String,
}
pub fn connect(handle: AppHandle, info: ConnectInfo) {
pub struct NetHandle {
pub send: NetSender,
pub thread: JoinHandle<()>,
}
#[derive(Default)]
pub enum NetState {
#[default]
None,
Connecting(JoinHandle<()>),
Connected(NetHandle),
}
impl NetState {
pub fn take_connection(&mut self) -> Option<NetHandle> {
match self.take() {
NetState::Connected(net_handle) => Some(net_handle),
_ => None,
}
}
pub fn connection(&self) -> Option<&NetHandle> {
match self {
NetState::Connected(net_handle) => Some(net_handle),
_ => None,
}
}
pub fn take(&mut self) -> Self {
std::mem::replace(self, Self::None)
}
}
pub fn connect(handle: AppHandle, info: ConnectInfo) -> JoinHandle<()> {
std::thread::spawn(move || {
if let Err(msg) = connect_the(handle.clone(), info) {
handle.send(ClientEvent::Err(msg));
}
});
})
}
type NetResult<T> = Result<T, String>;
type MsgPayload = ClientMsg;
#[derive(Clone)]
pub struct NetSender {
send: UnboundedSender<MsgPayload>,
send: UnboundedSender<NetCtrlMsg>,
}
pub enum NetCtrlMsg {
Send(ClientMsg),
Exit,
}
impl From<ClientMsg> for NetCtrlMsg {
fn from(value: ClientMsg) -> Self {
Self::Send(value)
}
}
impl NetSender {
pub fn send(&self, msg: ClientMsg) {
let _ = self.send.send(msg);
pub fn send(&self, msg: impl Into<NetCtrlMsg>) {
let _ = self.send.send(msg.into());
}
}
@@ -68,7 +114,7 @@ impl NetSender {
// .map_err(|e| format!("failed to connect: {}", e))
// }
async fn connection_no_cert(addr: SocketAddr) -> NetResult<Connection> {
async fn connection_no_cert(addr: SocketAddr) -> NetResult<(Endpoint, Connection)> {
let mut endpoint = Endpoint::client(CLIENT_SOCKET).map_err(|e| e.to_string())?;
let quic = QuicClientConfig::try_from(
@@ -91,16 +137,17 @@ async fn connection_no_cert(addr: SocketAddr) -> NetResult<Connection> {
endpoint.set_default_client_config(config);
// connect to server
endpoint
let con = endpoint
.connect(addr, SERVER_NAME)
.map_err(|e| e.to_string())?
.await
.map_err(|e| e.to_string())
.map_err(|e| e.to_string())?;
Ok((endpoint, con))
}
#[tokio::main]
async fn connect_the(handle: AppHandle, info: ConnectInfo) -> NetResult<()> {
let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::<MsgPayload>();
let (send, mut ui_recv) = tokio::sync::mpsc::unbounded_channel::<NetCtrlMsg>();
let addr = info
.ip
@@ -108,7 +155,7 @@ async fn connect_the(handle: AppHandle, info: ConnectInfo) -> NetResult<()> {
.map_err(|e| e.to_string())?
.next()
.ok_or("no addresses found".to_string())?;
let conn = connection_no_cert(addr).await?;
let (endpoint, conn) = connection_no_cert(addr).await?;
let conn_ = conn.clone();
handle.send(ClientEvent::Connect {
@@ -120,9 +167,18 @@ async fn connect_the(handle: AppHandle, info: ConnectInfo) -> NetResult<()> {
tokio::spawn(recv_uni(conn_, recv.into()));
while let Some(msg) = ui_recv.recv().await {
if send_uni(&conn, msg).await.is_err() {
println!("disconnected from server");
break;
match msg {
NetCtrlMsg::Send(msg) => {
if send_uni(&conn, msg).await.is_err() {
println!("disconnected from server");
break;
}
}
NetCtrlMsg::Exit => {
conn.close(quinn::VarInt::from_u32(0), &[]);
endpoint.wait_idle().await;
break;
}
}
}

View File

@@ -1,3 +1,5 @@
use crate::net::NetState;
use super::*;
pub fn login_screen(client: &mut Client) -> WidgetId {
@@ -5,13 +7,7 @@ pub fn login_screen(client: &mut Client) -> WidgetId {
ui, handle, data, ..
} = client;
let mut field = |name, hint_| {
text(name)
.editable(true)
.size(20)
.hint(hint(hint_))
.add(ui)
};
let mut field = |name, hint_| text(name).editable(true).size(20).hint(hint(hint_)).add(ui);
let ip = field(&data.ip, "ip");
let username = field(&data.username, "username");
// let password = field("password");
@@ -35,7 +31,8 @@ pub fn login_screen(client: &mut Client) -> WidgetId {
client.ui[id].color = color.darker(0.3);
let ip = client.ui[&ip_].content();
let username = client.ui[&username_].content();
connect(handle.clone(), ConnectInfo { ip, username });
let th = connect(handle.clone(), ConnectInfo { ip, username });
client.net = NetState::Connecting(th);
})
.height(40);
let modal = (

View File

@@ -18,7 +18,7 @@ pub fn main_view(client: &mut Client, network: NetSender) -> WidgetId {
.any()
}
pub fn msg_widget(msg: Msg) -> impl WidgetLike<FnTag> {
pub fn msg_widget(msg: NetMsg) -> impl WidgetLike<FnTag> {
let content = text(msg.content)
.editable(false)
.size(SIZE)
@@ -61,7 +61,7 @@ pub fn msg_panel(client: &mut Client, network: NetSender) -> impl WidgetFn<Sized
.clone()
.id_on(Submit, move |id, client: &mut Client, _| {
let content = client.ui.text(id).take();
let msg = Msg {
let msg = NetMsg {
content: content.clone(),
user: client.username.clone(),
};

View File

@@ -5,7 +5,7 @@ use crate::{
};
use iris::prelude::*;
use len_fns::*;
use openworm::net::{ClientMsg, Msg};
use openworm::net::{ClientMsg, NetMsg};
use winit::dpi::{LogicalPosition, LogicalSize};
mod login;

3
src/bin/server/data.rs Normal file
View File

@@ -0,0 +1,3 @@
pub struct DBMsg {
}

53
src/bin/server/db.rs Normal file
View File

@@ -0,0 +1,53 @@
use std::path::Path;
use bincode::{Decode, Encode};
use openworm::net::BINCODE_CONFIG;
use sled::{Db, Tree};
pub const DB_VERSION: u64 = 0;
pub fn open_db(path: impl AsRef<Path>) -> Db {
let db = sled::open(path).expect("failed to open database");
if !db.was_recovered() {
println!("no previous db found, creating new");
db.insert_("version", DB_VERSION);
db.flush().unwrap();
} else {
let version: u64 = db.get_("version").expect("failed to read db version");
println!("found existing db version {version}");
if version != DB_VERSION {
panic!("non matching db version! (auto update in the future)");
}
}
db
}
pub trait DbUtil {
fn insert_<K: AsRef<[u8]>, V: Encode>(&self, k: K, v: V);
fn get_<K: AsRef<[u8]>, V: Decode<()>>(&self, k: K) -> Option<V>;
fn iter_all<V: Decode<()>>(&self) -> impl Iterator<Item = V>;
}
impl DbUtil for Tree {
fn insert_<K: AsRef<[u8]>, V: Encode>(&self, k: K, v: V) {
let bytes = bincode::encode_to_vec(v, BINCODE_CONFIG).unwrap();
self.insert(k, bytes).unwrap();
}
fn get_<K: AsRef<[u8]>, V: Decode<()>>(&self, k: K) -> Option<V> {
let bytes = self.get(k).unwrap()?;
Some(
bincode::decode_from_slice(&bytes, BINCODE_CONFIG)
.unwrap()
.0,
)
}
fn iter_all<V: Decode<()>>(&self) -> impl Iterator<Item = V> {
self.iter().map(|r| {
bincode::decode_from_slice(&r.unwrap().1, BINCODE_CONFIG)
.unwrap()
.0
})
}
}

View File

@@ -1,10 +1,15 @@
// mod data;
mod db;
mod net;
use crate::db::{DbUtil, open_db};
use clap::Parser;
use net::{ClientSender, ConAccepter, listen};
use openworm::{
net::{ClientMsg, DisconnectReason, Msg, RecvHandler, ServerMsg, install_crypto_provider},
net::{ClientMsg, DisconnectReason, RecvHandler, ServerMsg, install_crypto_provider},
rsc::DataDir,
};
use sled::{Db, Tree};
use std::{
collections::HashMap,
sync::{
@@ -14,27 +19,39 @@ use std::{
};
use tokio::sync::RwLock;
#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
struct Args {
/// port to listen on
#[arg(short, long)]
port: u16,
}
fn main() {
let args = Args::parse();
install_crypto_provider();
run_server();
run_server(args.port);
}
#[tokio::main]
pub async fn run_server() {
pub async fn run_server(port: u16) {
let dir = DataDir::default();
let path = dir.get();
let db: Db = open_db(path.join("server.db"));
let handler = ServerListener {
msgs: Default::default(),
msgs: db.open_tree("msgs").unwrap(),
senders: Default::default(),
count: 0.into(),
db,
};
listen(path, handler).await;
listen(port, path, handler).await;
}
type ClientId = u64;
struct ServerListener {
msgs: Arc<RwLock<Vec<Msg>>>,
db: Db,
msgs: Tree,
senders: Arc<RwLock<HashMap<ClientId, ClientSender>>>,
count: AtomicU64,
}
@@ -44,6 +61,7 @@ impl ConAccepter for ServerListener {
let id = self.count.fetch_add(1, Ordering::Release);
self.senders.write().await.insert(id, send.clone());
ClientHandler {
db: self.db.clone(),
msgs: self.msgs.clone(),
senders: self.senders.clone(),
send,
@@ -53,17 +71,22 @@ impl ConAccepter for ServerListener {
}
struct ClientHandler {
msgs: Arc<RwLock<Vec<Msg>>>,
db: Db,
msgs: Tree,
send: ClientSender,
senders: Arc<RwLock<HashMap<ClientId, ClientSender>>>,
id: ClientId,
}
impl RecvHandler<ClientMsg> for ClientHandler {
async fn connect(&self) -> () {
println!("connected: {:?}", self.send.remote());
}
async fn msg(&self, msg: ClientMsg) {
match msg {
ClientMsg::SendMsg(msg) => {
self.msgs.write().await.push(msg.clone());
let id = self.db.generate_id().unwrap();
self.msgs.insert_(id.to_be_bytes(), &msg);
let mut handles = Vec::new();
for (&id, send) in self.senders.read().await.iter() {
if id == self.id {
@@ -81,13 +104,14 @@ impl RecvHandler<ClientMsg> for ClientHandler {
}
}
ClientMsg::RequestMsgs => {
let msgs = self.msgs.read().await.clone();
let msgs = self.msgs.iter_all().collect();
let _ = self.send.send(ServerMsg::LoadMsgs(msgs)).await;
}
}
}
async fn disconnect(&self, reason: DisconnectReason) -> () {
println!("disconnected: {:?}", self.send.remote());
match reason {
DisconnectReason::Closed | DisconnectReason::Timeout => (),
DisconnectReason::Other(e) => println!("connection issue: {e}"),

View File

@@ -12,12 +12,9 @@ use std::{
};
use tracing::Instrument;
pub const DEFAULT_PORT: u16 = 16839;
pub const SERVER_HOST: Ipv6Addr = Ipv6Addr::UNSPECIFIED;
pub const SERVER_SOCKET: SocketAddr =
SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, DEFAULT_PORT, 0, 0));
pub fn init_endpoint(data_path: &Path) -> Endpoint {
pub fn init_endpoint(port: u16, data_path: &Path) -> Endpoint {
let cert_path = data_path.join("cert.der");
let key_path = data_path.join("key.der");
let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) {
@@ -51,7 +48,8 @@ pub fn init_endpoint(data_path: &Path) -> Endpoint {
// let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
// transport_config.max_concurrent_uni_streams(0_u8.into());
quinn::Endpoint::server(server_config, SERVER_SOCKET).unwrap()
let server_socket: SocketAddr = SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, port, 0, 0));
quinn::Endpoint::server(server_config, server_socket).unwrap()
}
#[derive(Clone)]
@@ -60,6 +58,9 @@ pub struct ClientSender {
}
impl ClientSender {
pub fn remote(&self) -> SocketAddr {
self.conn.remote_address()
}
pub async fn send(&self, msg: ServerMsg) -> SendResult {
send_uni(&self.conn, msg).await
}
@@ -72,9 +73,9 @@ pub trait ConAccepter: Send + Sync + 'static {
) -> impl Future<Output = impl RecvHandler<ClientMsg>> + Send;
}
pub async fn listen(data_path: &Path, accepter: impl ConAccepter) {
pub async fn listen(port: u16, data_path: &Path, accepter: impl ConAccepter) {
let accepter = Arc::new(accepter);
let endpoint = init_endpoint(data_path);
let endpoint = init_endpoint(port, data_path);
println!("listening on {}", endpoint.local_addr().unwrap());
while let Some(conn) = endpoint.accept().await {
@@ -93,6 +94,7 @@ async fn handle_connection(
) -> std::io::Result<()> {
let conn = conn.await?;
let handler = Arc::new(accepter.accept(ClientSender { conn: conn.clone() }).await);
handler.connect().await;
let span = tracing::info_span!(
"connection",
remote = %conn.remote_address(),

View File

@@ -11,20 +11,20 @@ pub const BINCODE_CONFIG: Configuration = bincode::config::standard();
#[derive(Debug, bincode::Encode, bincode::Decode)]
pub enum ClientMsg {
SendMsg(Msg),
SendMsg(NetMsg),
RequestMsgs,
}
#[derive(Debug, bincode::Encode, bincode::Decode)]
pub enum ServerMsg {
SendMsg(Msg),
LoadMsgs(Vec<Msg>),
SendMsg(NetMsg),
LoadMsgs(Vec<NetMsg>),
}
pub type ServerResp<T> = Result<T, String>;
#[derive(Debug, Clone, bincode::Encode, bincode::Decode)]
pub struct Msg {
pub struct NetMsg {
pub content: String,
pub user: String,
}

View File

@@ -6,9 +6,13 @@ use tokio::io::{AsyncReadExt as _, AsyncWriteExt};
use tracing::Instrument as _;
pub trait RecvHandler<M>: Send + Sync + 'static {
fn connect(&self) -> impl Future<Output = ()> + Send {
async {}
}
fn msg(&self, msg: M) -> impl Future<Output = ()> + Send;
#[allow(unused)]
fn disconnect(&self, reason: DisconnectReason) -> impl Future<Output = ()> + Send {
async { drop(reason) }
async {}
}
}
@@ -39,6 +43,7 @@ pub async fn recv_uni<M: bincode::Decode<()>>(
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
return DisconnectReason::Closed;
}
Err(quinn::ConnectionError::LocallyClosed) => return DisconnectReason::Closed,
Err(quinn::ConnectionError::TimedOut) => {
return DisconnectReason::Timeout;
}

View File

@@ -7,6 +7,7 @@ use directories_next::ProjectDirs;
use crate::net::BINCODE_CONFIG;
#[derive(Clone)]
pub struct DataDir {
dirs: ProjectDirs,
}