refactor: migrate net operation to KCP

This commit is contained in:
amizing25 2024-10-23 04:49:46 +07:00
parent c4794570d1
commit f77bcb04f7
13 changed files with 1946 additions and 81 deletions

View File

@ -31,3 +31,4 @@ proto.workspace = true
proto-derive.workspace = true
rand.workspace = true
mhy-kcp.workspace = true

View File

@ -1,3 +1,4 @@
#![feature(let_chains)]
use anyhow::Result;
mod logging;
@ -6,11 +7,11 @@ mod tools;
mod util;
use logging::init_tracing;
use net::gateway::Gateway;
#[tokio::main]
async fn main() -> Result<()> {
init_tracing();
net::gateway::listen("0.0.0.0", 23301).await?;
Ok(())
let mut gateway = Gateway::new("0.0.0.0", 23301).await?;
Box::pin(gateway.listen()).await
}

View File

@ -1,28 +1,143 @@
use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicU32, Ordering},
Arc,
},
};
use anyhow::Result;
use tokio::net::TcpListener;
use tracing::{info_span, Instrument};
use rand::RngCore;
use crate::{log_error, net::PlayerSession};
use crate::net::PlayerSession;
pub async fn listen(host: &str, port: u16) -> Result<()> {
let listener = TcpListener::bind(format!("{host}:{port}")).await?;
tracing::info!("Listening at {host}:{port}");
use tokio::{
net::UdpSocket,
sync::{Mutex, RwLock},
};
loop {
let Ok((client_socket, client_addr)) = listener.accept().await else {
continue;
use crate::net::packet::NetOperation;
const MAX_PACKET_SIZE: usize = 1200;
pub struct Gateway {
socket: Arc<UdpSocket>,
id_counter: AtomicU32,
sessions: Mutex<HashMap<u32, Arc<RwLock<PlayerSession>>>>,
}
impl Gateway {
pub async fn new(host: &str, port: u16) -> Result<Self> {
let socket = Arc::new(UdpSocket::bind(format!("{host}:{port}")).await?);
Ok(Self {
socket,
id_counter: AtomicU32::new(0),
sessions: Mutex::new(HashMap::new()),
})
}
pub async fn listen(&mut self) -> Result<()> {
tracing::info!(
"KCP Gateway is listening at {}",
self.socket.local_addr().unwrap()
);
let mut buf = [0; MAX_PACKET_SIZE];
loop {
let Ok((len, addr)) = self.socket.recv_from(&mut buf).await else {
continue;
};
match len {
20 => self.process_net_operation(buf[..len].into(), addr).await?,
28.. => self.process_kcp_payload(buf[..len].into(), addr).await,
_ => {
tracing::warn!("unk data len {len}")
}
}
}
}
async fn process_net_operation(&mut self, op: NetOperation, addr: SocketAddr) -> Result<()> {
match (op.head, op.tail) {
(0xFF, 0xFFFFFFFF) => self.establish_kcp_session(op.data, addr).await?,
(0x194, 0x19419494) => self.drop_kcp_session(op.param1, op.param2, addr).await,
_ => tracing::warn!("Unknown magic pair received {:X}-{:X}", op.head, op.tail),
}
Ok(())
}
async fn establish_kcp_session(&mut self, data: u32, addr: SocketAddr) -> Result<()> {
tracing::info!("New connection from {addr}");
let (conv_id, session_token) = self.next_conv_pair();
self.sessions.lock().await.insert(
conv_id,
Arc::new(RwLock::new(PlayerSession::new(
self.socket.clone(),
addr,
conv_id,
session_token,
))),
);
self.socket
.send_to(
&Vec::from(NetOperation {
head: 0x145,
param1: conv_id,
param2: session_token,
data,
tail: 0x14514545,
}),
addr,
)
.await?;
Ok(())
}
async fn drop_kcp_session(&mut self, conv_id: u32, token: u32, addr: SocketAddr) {
tracing::info!("drop_kcp_session {conv_id} {token}");
let Some(session) = self.sessions.lock().await.get(&conv_id).cloned() else {
tracing::warn!("drop_kcp_session failed, no session with conv_id {conv_id} was found");
return;
};
let mut session = PlayerSession::new(client_socket);
tokio::spawn(
async move {
log_error!(
"Session from {client_addr} disconnected",
format!("An error occurred while processing session ({client_addr})"),
Box::pin(session.run()).await
);
if session.read().await.token == token {
self.sessions.lock().await.remove(&conv_id);
tracing::info!("Client from {addr} disconnected");
}
}
async fn process_kcp_payload(&mut self, data: Box<[u8]>, addr: SocketAddr) {
let conv_id = mhy_kcp::get_conv(&data);
let Some(session) = self
.sessions
.lock()
.await
.get_mut(&conv_id)
.map(|s| s.clone())
else {
tracing::warn!("Session with conv_id {conv_id} not found!");
return;
};
tokio::spawn(async move {
if let Err(err) = Box::pin(session.write().await.consume(&data)).await {
tracing::error!("An error occurred while processing session ({addr}): {err}");
}
.instrument(info_span!("session", addr = %client_addr)),
);
});
}
fn next_conv_pair(&mut self) -> (u32, u32) {
(
self.id_counter.fetch_add(1, Ordering::SeqCst) + 1,
rand::thread_rng().next_u32(),
)
}
}

View File

@ -67,7 +67,18 @@ async fn create_battle_info(caster_id: u32, skill_index: u32) -> SceneBattleInfo
// avatars
for (avatar_index, avatar_id) in player.lineups.iter() {
if let Some(avatar) = player.avatars.get(avatar_id) {
let is_trailblazer = *avatar_id == 8001;
let is_march = *avatar_id == 1001;
let avatar_id = if is_trailblazer {
player.main_character as u32
} else if is_march {
player.march_type as u32
} else {
*avatar_id
};
if let Some(avatar) = player.avatars.get(&avatar_id) {
let (battle_avatar, techs) = avatar.to_battle_avatar_proto(
*avatar_index,
player
@ -84,31 +95,22 @@ async fn create_battle_info(caster_id: u32, skill_index: u32) -> SceneBattleInfo
battle_info.buff_list.push(tech);
}
if caster_id > 0 && *avatar_index == (caster_id - 1) {
let is_trailblazer = *avatar_id == 8001;
let is_march = *avatar_id == 1001;
let avatar_id = if is_trailblazer {
player.main_character as u32
} else if is_march {
player.march_type as u32
} else {
*avatar_id
};
if let Some(avatar_config) = GAME_RES.avatar_configs.get(&avatar_id) {
battle_info.buff_list.push(BattleBuff {
id: avatar_config.weakness_buff_id,
level: 1,
owner_id: *avatar_index,
wave_flag: 0xffffffff,
dynamic_values: HashMap::from([(
String::from("SkillIndex"),
skill_index as f32,
)]),
..Default::default()
});
}
if caster_id > 0
&& *avatar_index == (caster_id - 1)
&& let Some(avatar_config) = GAME_RES.avatar_configs.get(&avatar_id)
&& !avatar.techniques.contains(&1000119)
{
battle_info.buff_list.push(BattleBuff {
id: avatar_config.weakness_buff_id,
level: 1,
owner_id: *avatar_index,
wave_flag: 0xffffffff,
dynamic_values: HashMap::from([(
String::from("SkillIndex"),
skill_index as f32,
)]),
..Default::default()
});
}
battle_info.battle_avatar_list.push(battle_avatar);

View File

@ -13,7 +13,6 @@ mod scene;
use anyhow::Result;
use paste::paste;
use proto::*;
use tokio::io::AsyncWriteExt;
use super::PlayerSession;
use crate::net::NetPacket;
@ -75,14 +74,11 @@ macro_rules! dummy {
_ => return Err(anyhow::anyhow!("Invalid request id {req_id:?}")),
};
let payload: Vec<u8> = NetPacket {
self.send_raw(NetPacket {
cmd_type,
head: Vec::new(),
body: Vec::new(),
}
.into();
self.client_socket.write_all(&payload).await?;
}).await?;
Ok(())
}

View File

@ -1,7 +1,5 @@
use anyhow::Result;
use paste::paste;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tracing::Instrument;
use proto::*;
@ -34,6 +32,16 @@ use super::PlayerSession;
const HEAD_MAGIC: u32 = 0x9D74C714;
const TAIL_MAGIC: u32 = 0xD7A152C8;
#[derive(Debug)]
pub struct NetOperation {
pub head: u32,
pub param1: u32,
pub param2: u32,
pub data: u32,
pub tail: u32,
}
#[derive(Debug)]
pub struct NetPacket {
pub cmd_type: u16,
pub head: Vec<u8>,
@ -55,27 +63,62 @@ impl From<NetPacket> for Vec<u8> {
}
}
impl NetPacket {
pub async fn read(stream: &mut TcpStream) -> std::io::Result<Self> {
assert_eq!(stream.read_u32().await?, HEAD_MAGIC);
let cmd_type = stream.read_u16().await?;
impl From<&[u8]> for NetPacket {
fn from(value: &[u8]) -> Self {
assert_eq!(
u32::from_be_bytes(value[0..4].try_into().unwrap()),
HEAD_MAGIC
);
let head_length = stream.read_u16().await? as usize;
let body_length = stream.read_u32().await? as usize;
let cmd_type = u16::from_be_bytes(value[4..6].try_into().unwrap());
let mut head = vec![0; head_length];
stream.read_exact(&mut head).await?;
let head_length = usize::from(u16::from_be_bytes(value[6..8].try_into().unwrap()));
let mut body = vec![0; body_length];
stream.read_exact(&mut body).await?;
let body_length = u32::from_be_bytes(value[8..12].try_into().unwrap()) as usize;
assert_eq!(stream.read_u32().await?, TAIL_MAGIC);
let head_start = 12;
let head_end = head_start + head_length;
let head = value[head_start..head_end].to_vec();
Ok(Self {
let body_start = head_end;
let body_end = body_start + body_length;
let body = value[body_start..body_end].to_vec();
assert_eq!(
u32::from_be_bytes(value[body_end..body_end + 4].try_into().unwrap()),
TAIL_MAGIC
);
Self {
cmd_type,
head,
body,
})
}
}
}
impl From<&[u8]> for NetOperation {
fn from(value: &[u8]) -> Self {
Self {
head: u32::from_be_bytes(value[..4].try_into().unwrap()),
param1: u32::from_be_bytes(value[4..8].try_into().unwrap()),
param2: u32::from_be_bytes(value[8..12].try_into().unwrap()),
data: u32::from_be_bytes(value[12..16].try_into().unwrap()),
tail: u32::from_be_bytes(value[16..20].try_into().unwrap()),
}
}
}
impl From<NetOperation> for Vec<u8> {
fn from(value: NetOperation) -> Self {
let mut buf = Self::with_capacity(20);
buf.extend(value.head.to_be_bytes());
buf.extend(value.param1.to_be_bytes());
buf.extend(value.param2.to_be_bytes());
buf.extend(value.data.to_be_bytes());
buf.extend(value.tail.to_be_bytes());
buf
}
}

View File

@ -1,27 +1,73 @@
use std::{
io::Error,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use anyhow::Result;
use mhy_kcp::Kcp;
use prost::Message;
use proto::CmdID;
use tokio::{io::AsyncWriteExt, net::TcpStream};
use tokio::{io::AsyncWrite, net::UdpSocket, sync::Mutex};
use crate::util;
use super::{packet::CommandHandler, NetPacket};
struct RemoteEndPoint {
socket: Arc<UdpSocket>,
addr: SocketAddr,
}
pub struct PlayerSession {
pub(crate) client_socket: TcpStream,
pub token: u32,
kcp: Mutex<Kcp<RemoteEndPoint>>,
start_time: u64,
}
impl PlayerSession {
pub const fn new(client_socket: TcpStream) -> Self {
Self { client_socket }
}
pub async fn run(&mut self) -> Result<()> {
loop {
let net_packet = NetPacket::read(&mut self.client_socket).await?;
Self::on_message(self, net_packet.cmd_type, net_packet.body).await?;
pub fn new(socket: Arc<UdpSocket>, addr: SocketAddr, conv: u32, token: u32) -> Self {
Self {
token,
kcp: Mutex::new(Kcp::new(
conv,
token,
false,
RemoteEndPoint { socket, addr },
)),
start_time: util::cur_timestamp_secs(),
}
}
pub async fn send(&mut self, body: impl Message + CmdID) -> Result<()> {
pub async fn consume(&mut self, buffer: &[u8]) -> Result<()> {
{
let mut kcp = self.kcp.lock().await;
kcp.input(buffer)?;
kcp.async_update(self.session_time() as u32).await?;
kcp.async_flush().await?;
}
let mut packets = Vec::new();
let mut buf = [0; 24756];
while let Ok(length) = self.kcp.lock().await.recv(&mut buf) {
packets.push(NetPacket::from(&buf[..length]));
}
for packet in packets {
Self::on_message(self, packet.cmd_type, packet.body).await?;
}
self.kcp
.lock()
.await
.async_update(self.session_time() as u32)
.await?;
Ok(())
}
pub async fn send(&self, body: impl Message + CmdID) -> Result<()> {
let mut buf = Vec::new();
body.encode(&mut buf)?;
@ -32,10 +78,46 @@ impl PlayerSession {
}
.into();
self.client_socket.write_all(&payload).await?;
let mut kcp = self.kcp.lock().await;
kcp.send(&payload)?;
kcp.async_flush().await?;
kcp.async_update(self.session_time() as u32).await?;
Ok(())
}
pub async fn send_raw(&self, payload: NetPacket) -> Result<()> {
let mut kcp = self.kcp.lock().await;
let payload: Vec<u8> = payload.into();
kcp.send(&payload)?;
kcp.async_flush().await?;
kcp.async_update(self.session_time() as u32).await?;
Ok(())
}
fn session_time(&self) -> u64 {
util::cur_timestamp_secs() - self.start_time
}
}
// Auto implemented
impl CommandHandler for PlayerSession {}
impl AsyncWrite for RemoteEndPoint {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
self.socket.poll_send_to(cx, buf, self.addr)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}
}

View File

@ -481,6 +481,7 @@ pub struct Position {
}
impl Position {
#[allow(unused)]
pub fn is_empty(&self) -> bool {
self.x == 0 && self.y == 0 && self.z == 0
}
@ -585,6 +586,7 @@ impl From<u32> for MultiPathAvatar {
}
impl MultiPathAvatar {
#[allow(unused)]
pub fn get_gender(&self) -> Gender {
if (*self as u32) < 8000 {
Gender::None

View File

@ -6,3 +6,10 @@ pub fn cur_timestamp_ms() -> u64 {
.unwrap()
.as_millis() as u64
}
pub fn cur_timestamp_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs()
}

20
kcp/Cargo.toml Normal file
View File

@ -0,0 +1,20 @@
[package]
name = "mhy-kcp"
version.workspace = true
edition = "2021"
[features]
fastack-conserve = []
tokio = ["dep:tokio"]
[dependencies]
bytes = "1.6.0"
log = "0.4.21"
thiserror = "1.0.58"
tokio = { version = "1.37.0", optional = true, features = ["io-util"] }
[dev-dependencies]
time = "0.3.34"
rand = "0.8.5"
env_logger = "0.11.3"

56
kcp/src/error.rs Normal file
View File

@ -0,0 +1,56 @@
use std::{
error::Error as StdError,
io::{self, ErrorKind},
};
/// KCP protocol errors
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("conv inconsistent, expected {0}, found {1}")]
ConvInconsistent(u32, u32),
#[error("invalid mtu {0}")]
InvalidMtu(usize),
#[error("invalid segment size {0}")]
InvalidSegmentSize(usize),
#[error("invalid segment data size, expected {0}, found {1}")]
InvalidSegmentDataSize(usize, usize),
#[error("{0}")]
IoError(
#[from]
#[source]
io::Error,
),
#[error("need to call update() once")]
NeedUpdate,
#[error("recv queue is empty")]
RecvQueueEmpty,
#[error("expecting fragment")]
ExpectingFragment,
#[error("command {0} is not supported")]
UnsupportedCmd(u8),
#[error("user's send buffer is too big")]
UserBufTooBig,
#[error("user's recv buffer is too small")]
UserBufTooSmall,
#[error("token mismatch, expected {0}, found {1}")]
TokenMismatch(u32, u32),
}
fn make_io_error<T>(kind: ErrorKind, msg: T) -> io::Error
where
T: Into<Box<dyn StdError + Send + Sync>>,
{
io::Error::new(kind, msg)
}
impl From<Error> for io::Error {
fn from(err: Error) -> Self {
let kind = match err {
Error::IoError(err) => return err,
Error::RecvQueueEmpty | Error::ExpectingFragment => ErrorKind::WouldBlock,
_ => ErrorKind::Other,
};
make_io_error(kind, err)
}
}

1523
kcp/src/kcp.rs Normal file

File diff suppressed because it is too large Load Diff

17
kcp/src/lib.rs Normal file
View File

@ -0,0 +1,17 @@
extern crate bytes;
#[macro_use]
extern crate log;
mod error;
mod kcp;
/// The `KCP` prelude
pub mod prelude {
pub use super::{get_conv, Kcp, KCP_OVERHEAD};
}
pub use error::Error;
pub use kcp::{get_conv, get_sn, set_conv, Kcp, KCP_OVERHEAD};
/// KCP result
pub type KcpResult<T> = Result<T, Error>;