150 lines
4.5 KiB
Rust
150 lines
4.5 KiB
Rust
use openworm::net::{
|
|
ClientMsgInst, RecvHandler, RequestId, SERVER_NAME, SendResult, ServerMsg, ServerMsgInst,
|
|
recv_uni, send_uni,
|
|
};
|
|
use quinn::{
|
|
Connection, Endpoint, ServerConfig,
|
|
rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
|
|
};
|
|
use std::{fs, path::Path};
|
|
use std::{
|
|
net::{Ipv6Addr, SocketAddr, SocketAddrV6},
|
|
sync::Arc,
|
|
};
|
|
use tokio::task::JoinHandle;
|
|
use tracing::Instrument;
|
|
|
|
/// Address the server binds to: the IPv6 unspecified address `::` (all interfaces).
pub const SERVER_HOST: Ipv6Addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
|
|
|
|
pub fn init_endpoint(port: u16, data_path: &Path) -> Endpoint {
|
|
let cert_path = data_path.join("cert.der");
|
|
let key_path = data_path.join("key.der");
|
|
let (cert, key) = match fs::read(&cert_path).and_then(|x| Ok((x, fs::read(&key_path)?))) {
|
|
Ok((cert, key)) => (
|
|
CertificateDer::from(cert),
|
|
PrivateKeyDer::try_from(key).unwrap(),
|
|
),
|
|
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
|
|
let cert = rcgen::generate_simple_self_signed([SERVER_NAME.into()]).unwrap();
|
|
let key = PrivatePkcs8KeyDer::from(cert.signing_key.serialize_der());
|
|
let cert = cert.cert.into();
|
|
fs::create_dir_all(data_path).expect("failed to create certificate directory");
|
|
fs::write(&cert_path, &cert).expect("failed to write certificate");
|
|
fs::write(&key_path, key.secret_pkcs8_der()).expect("failed to write private key");
|
|
(cert, key.into())
|
|
}
|
|
Err(e) => {
|
|
panic!("failed to read certificate: {}", e);
|
|
}
|
|
};
|
|
print!("cert hex: ");
|
|
for x in cert.iter() {
|
|
print!("{:02x}", x);
|
|
}
|
|
println!();
|
|
let server_config = ServerConfig::with_single_cert(vec![cert], key).unwrap();
|
|
let server_socket: SocketAddr = SocketAddr::V6(SocketAddrV6::new(SERVER_HOST, port, 0, 0));
|
|
quinn::Endpoint::server(server_config, server_socket).unwrap()
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ClientSender {
|
|
conn: Connection,
|
|
}
|
|
|
|
impl ClientSender {
|
|
pub fn remote(&self) -> SocketAddr {
|
|
self.conn.remote_address()
|
|
}
|
|
|
|
pub fn replier(&self, id: RequestId) -> ClientReplier {
|
|
ClientReplier {
|
|
conn: self.conn.clone(),
|
|
req_id: id,
|
|
}
|
|
}
|
|
|
|
pub async fn send(&self, msg: impl Into<ServerMsg>) -> SendResult {
|
|
let msg = ServerMsgInst {
|
|
id: None,
|
|
msg: msg.into(),
|
|
};
|
|
send_uni(&self.conn, msg).await
|
|
}
|
|
}
|
|
|
|
pub struct ClientReplier {
|
|
conn: Connection,
|
|
req_id: RequestId,
|
|
}
|
|
|
|
impl ClientReplier {
|
|
pub async fn send(&self, msg: impl Into<ServerMsg>) {
|
|
let msg = ServerMsgInst {
|
|
id: Some(self.req_id),
|
|
msg: msg.into(),
|
|
};
|
|
let _ = send_uni(&self.conn, msg).await;
|
|
}
|
|
}
|
|
|
|
pub trait ConAccepter: Send + Sync + 'static {
|
|
fn accept(
|
|
&self,
|
|
send: ClientSender,
|
|
) -> impl Future<Output = impl RecvHandler<ClientMsgInst>> + Send;
|
|
}
|
|
|
|
pub fn listen(
|
|
port: u16,
|
|
data_path: &Path,
|
|
accepter: impl ConAccepter,
|
|
) -> (Endpoint, JoinHandle<()>) {
|
|
let accepter = Arc::new(accepter);
|
|
let endpoint = init_endpoint(port, data_path);
|
|
let res = endpoint.clone();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
println!("listening on {}", endpoint.local_addr().unwrap());
|
|
let mut tasks = Vec::new();
|
|
while let Some(conn) = endpoint.accept().await {
|
|
let fut = handle_connection(conn, accepter.clone());
|
|
tasks.push(tokio::spawn(async move {
|
|
if let Err(e) = fut.await {
|
|
eprintln!("connection failed: {reason}", reason = e)
|
|
}
|
|
}));
|
|
}
|
|
for task in tasks {
|
|
let _ = task.await;
|
|
}
|
|
});
|
|
(res, handle)
|
|
}
|
|
|
|
async fn handle_connection(
|
|
conn: quinn::Incoming,
|
|
accepter: Arc<impl ConAccepter>,
|
|
) -> std::io::Result<()> {
|
|
let conn = conn.await?;
|
|
let handler = Arc::new(accepter.accept(ClientSender { conn: conn.clone() }).await);
|
|
handler.connect().await;
|
|
let span = tracing::info_span!(
|
|
"connection",
|
|
remote = %conn.remote_address(),
|
|
protocol = %conn
|
|
.handshake_data()
|
|
.unwrap()
|
|
.downcast::<quinn::crypto::rustls::HandshakeData>().unwrap()
|
|
.protocol
|
|
.map_or_else(|| "<none>".into(), |x| String::from_utf8_lossy(&x).into_owned())
|
|
);
|
|
async {
|
|
let res = recv_uni(conn, handler.clone()).await;
|
|
handler.disconnect(res).await;
|
|
}
|
|
.instrument(span)
|
|
.await;
|
|
Ok(())
|
|
}
|