persistence + proper disconnect
This commit is contained in:
3
src/bin/server/data.rs
Normal file
3
src/bin/server/data.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub struct DBMsg {
|
||||
|
||||
}
|
||||
53
src/bin/server/db.rs
Normal file
53
src/bin/server/db.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use std::path::Path;
|
||||
|
||||
use bincode::{Decode, Encode};
|
||||
use openworm::net::BINCODE_CONFIG;
|
||||
use sled::{Db, Tree};
|
||||
|
||||
pub const DB_VERSION: u64 = 0;
|
||||
|
||||
pub fn open_db(path: impl AsRef<Path>) -> Db {
|
||||
let db = sled::open(path).expect("failed to open database");
|
||||
if !db.was_recovered() {
|
||||
println!("no previous db found, creating new");
|
||||
db.insert_("version", DB_VERSION);
|
||||
db.flush().unwrap();
|
||||
} else {
|
||||
let version: u64 = db.get_("version").expect("failed to read db version");
|
||||
println!("found existing db version {version}");
|
||||
if version != DB_VERSION {
|
||||
panic!("non matching db version! (auto update in the future)");
|
||||
}
|
||||
}
|
||||
db
|
||||
}
|
||||
|
||||
pub trait DbUtil {
|
||||
fn insert_<K: AsRef<[u8]>, V: Encode>(&self, k: K, v: V);
|
||||
fn get_<K: AsRef<[u8]>, V: Decode<()>>(&self, k: K) -> Option<V>;
|
||||
fn iter_all<V: Decode<()>>(&self) -> impl Iterator<Item = V>;
|
||||
}
|
||||
|
||||
impl DbUtil for Tree {
|
||||
fn insert_<K: AsRef<[u8]>, V: Encode>(&self, k: K, v: V) {
|
||||
let bytes = bincode::encode_to_vec(v, BINCODE_CONFIG).unwrap();
|
||||
self.insert(k, bytes).unwrap();
|
||||
}
|
||||
|
||||
fn get_<K: AsRef<[u8]>, V: Decode<()>>(&self, k: K) -> Option<V> {
|
||||
let bytes = self.get(k).unwrap()?;
|
||||
Some(
|
||||
bincode::decode_from_slice(&bytes, BINCODE_CONFIG)
|
||||
.unwrap()
|
||||
.0,
|
||||
)
|
||||
}
|
||||
|
||||
fn iter_all<V: Decode<()>>(&self) -> impl Iterator<Item = V> {
|
||||
self.iter().map(|r| {
|
||||
bincode::decode_from_slice(&r.unwrap().1, BINCODE_CONFIG)
|
||||
.unwrap()
|
||||
.0
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
// mod data;
|
||||
mod db;
|
||||
mod net;
|
||||
|
||||
use crate::db::{DbUtil, open_db};
|
||||
use clap::Parser;
|
||||
use net::{ClientSender, ConAccepter, listen};
|
||||
use openworm::{
|
||||
net::{ClientMsg, DisconnectReason, Msg, RecvHandler, ServerMsg, install_crypto_provider},
|
||||
net::{ClientMsg, DisconnectReason, RecvHandler, ServerMsg, install_crypto_provider},
|
||||
rsc::DataDir,
|
||||
};
|
||||
use sled::{Db, Tree};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
@@ -14,27 +19,39 @@ use std::{
|
||||
};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// port to listen on
|
||||
#[arg(short, long)]
|
||||
port: u16,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args = Args::parse();
|
||||
install_crypto_provider();
|
||||
run_server();
|
||||
run_server(args.port);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
pub async fn run_server() {
|
||||
pub async fn run_server(port: u16) {
|
||||
let dir = DataDir::default();
|
||||
let path = dir.get();
|
||||
let db: Db = open_db(path.join("server.db"));
|
||||
let handler = ServerListener {
|
||||
msgs: Default::default(),
|
||||
msgs: db.open_tree("msgs").unwrap(),
|
||||
senders: Default::default(),
|
||||
count: 0.into(),
|
||||
db,
|
||||
};
|
||||
listen(path, handler).await;
|
||||
listen(port, path, handler).await;
|
||||
}
|
||||
|
||||
type ClientId = u64;
|
||||
|
||||
struct ServerListener {
|
||||
msgs: Arc<RwLock<Vec<Msg>>>,
|
||||
db: Db,
|
||||
msgs: Tree,
|
||||
senders: Arc<RwLock<HashMap<ClientId, ClientSender>>>,
|
||||
count: AtomicU64,
|
||||
}
|
||||
@@ -44,6 +61,7 @@ impl ConAccepter for ServerListener {
|
||||
let id = self.count.fetch_add(1, Ordering::Release);
|
||||
self.senders.write().await.insert(id, send.clone());
|
||||
ClientHandler {
|
||||
db: self.db.clone(),
|
||||
msgs: self.msgs.clone(),
|
||||
senders: self.senders.clone(),
|
||||
send,
|
||||
@@ -53,17 +71,22 @@ impl ConAccepter for ServerListener {
|
||||
}
|
||||
|
||||
struct ClientHandler {
|
||||
msgs: Arc<RwLock<Vec<Msg>>>,
|
||||
db: Db,
|
||||
msgs: Tree,
|
||||
send: ClientSender,
|
||||
senders: Arc<RwLock<HashMap<ClientId, ClientSender>>>,
|
||||
id: ClientId,
|
||||
}
|
||||
|
||||
impl RecvHandler<ClientMsg> for ClientHandler {
|
||||
async fn connect(&self) -> () {
|
||||
println!("connected: {:?}", self.send.remote());
|
||||
}
|
||||
async fn msg(&self, msg: ClientMsg) {
|
||||
match msg {
|
||||
ClientMsg::SendMsg(msg) => {
|
||||
self.msgs.write().await.push(msg.clone());
|
||||
let id = self.db.generate_id().unwrap();
|
||||
self.msgs.insert_(id.to_be_bytes(), &msg);
|
||||
let mut handles = Vec::new();
|
||||
for (&id, send) in self.senders.read().await.iter() {
|
||||
if id == self.id {
|
||||
@@ -81,13 +104,14 @@ impl RecvHandler<ClientMsg> for ClientHandler {
|
||||
}
|
||||
}
|
||||
ClientMsg::RequestMsgs => {
|
||||
let msgs = self.msgs.read().await.clone();
|
||||
let msgs = self.msgs.iter_all().collect();
|
||||
let _ = self.send.send(ServerMsg::LoadMsgs(msgs)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn disconnect(&self, reason: DisconnectReason) -> () {
|
||||
println!("disconnected: {:?}", self.send.remote());
|
||||
match reason {
|
||||
DisconnectReason::Closed | DisconnectReason::Timeout => (),
|
||||
DisconnectReason::Other(e) => println!("connection issue: {e}"),
|
||||
|
||||
@@ -12,12 +12,9 @@ use std::{
|
||||
};
|
||||
use tracing::Instrument;
|
||||
|
||||
pub const DEFAULT_PORT: u16 = 16839;
|
||||
pub const SERVER_HOST: Ipv6Addr = Ipv6Addr::UNSPECIFIED;
|
||||
pub const SERVER_SOCKET: SocketAddr =
|
||||
SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, DEFAULT_PORT, 0, 0));
|
||||
|
||||
pub fn init_endpoint(data_path: &Path) -> Endpoint {
|
||||
pub fn init_endpoint(port: u16, data_path: &Path) -> Endpoint {
|
||||
let cert_path = data_path.join("cert.der");
|
||||
let key_path = data_path.join("key.der");
|
||||
let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) {
|
||||
@@ -51,7 +48,8 @@ pub fn init_endpoint(data_path: &Path) -> Endpoint {
|
||||
// let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
|
||||
// transport_config.max_concurrent_uni_streams(0_u8.into());
|
||||
|
||||
quinn::Endpoint::server(server_config, SERVER_SOCKET).unwrap()
|
||||
let server_socket: SocketAddr = SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, port, 0, 0));
|
||||
quinn::Endpoint::server(server_config, server_socket).unwrap()
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -60,6 +58,9 @@ pub struct ClientSender {
|
||||
}
|
||||
|
||||
impl ClientSender {
|
||||
pub fn remote(&self) -> SocketAddr {
|
||||
self.conn.remote_address()
|
||||
}
|
||||
pub async fn send(&self, msg: ServerMsg) -> SendResult {
|
||||
send_uni(&self.conn, msg).await
|
||||
}
|
||||
@@ -72,9 +73,9 @@ pub trait ConAccepter: Send + Sync + 'static {
|
||||
) -> impl Future<Output = impl RecvHandler<ClientMsg>> + Send;
|
||||
}
|
||||
|
||||
pub async fn listen(data_path: &Path, accepter: impl ConAccepter) {
|
||||
pub async fn listen(port: u16, data_path: &Path, accepter: impl ConAccepter) {
|
||||
let accepter = Arc::new(accepter);
|
||||
let endpoint = init_endpoint(data_path);
|
||||
let endpoint = init_endpoint(port, data_path);
|
||||
println!("listening on {}", endpoint.local_addr().unwrap());
|
||||
|
||||
while let Some(conn) = endpoint.accept().await {
|
||||
@@ -93,6 +94,7 @@ async fn handle_connection(
|
||||
) -> std::io::Result<()> {
|
||||
let conn = conn.await?;
|
||||
let handler = Arc::new(accepter.accept(ClientSender { conn: conn.clone() }).await);
|
||||
handler.connect().await;
|
||||
let span = tracing::info_span!(
|
||||
"connection",
|
||||
remote = %conn.remote_address(),
|
||||
|
||||
Reference in New Issue
Block a user