diff --git a/Cargo.toml b/Cargo.toml index adb0ed7..f80a4f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ tokio = { version = "1.36.0", features = ["full"] } axum = "0.8.1" axum-server = "0.7.1" tower-http = "0.6.2" +reqwest = "0.12.12" # JSON serde = { version = "1.0.197", features = ["derive"] } diff --git a/sdkserver/Cargo.toml b/sdkserver/Cargo.toml index ae7a8e2..1a2d1a6 100644 --- a/sdkserver/Cargo.toml +++ b/sdkserver/Cargo.toml @@ -9,6 +9,7 @@ tokio.workspace = true tower-http = { workspace = true, features = ["cors"] } axum.workspace = true axum-server.workspace = true +reqwest.workspace = true # JSON serde.workspace = true diff --git a/sdkserver/src/config/version_config.rs b/sdkserver/src/config/version_config.rs index 97b500d..2e142f1 100644 --- a/sdkserver/src/config/version_config.rs +++ b/sdkserver/src/config/version_config.rs @@ -1,27 +1,127 @@ -use std::{collections::HashMap, sync::OnceLock}; +use std::collections::HashMap; -use serde::Deserialize; +use anyhow::{Result, anyhow}; +use prost::Message as _; +use proto::Gateserver; +use serde::{Deserialize, Serialize}; +use tokio::fs; const DEFAULT_VERSIONS: &str = include_str!("../../../versions.json"); -#[derive(Deserialize)] +#[derive(Serialize, Deserialize, Clone, Default)] pub struct VersionConfig { pub asset_bundle_url: String, pub ex_resource_url: String, pub lua_url: String, - // pub lua_version: String, + pub ifix_url: String, } -pub fn instance() -> &'static HashMap { - static INSTANCE: OnceLock> = OnceLock::new(); - INSTANCE.get_or_init(|| { +impl VersionConfig { + const CNPROD_HOST: &str = "prod-gf-cn-dp01.bhsr.com"; + const CNBETA_HOST: &str = "beta-release01-cn.bhsr.com"; + const OSPROD_HOST: &str = "prod-official-asia-dp01.starrails.com"; + const OSBETA_HOST: &str = "beta-release01-asia.starrails.com"; + + const PROXY_HOST: &str = "proxy1.neonteam.dev"; // we used this because PS users usually have a proxy enabled + + pub async fn load_hotfix() -> HashMap { const CONFIG_PATH: &str = "versions.json"; - let data = std::fs::read_to_string(CONFIG_PATH).unwrap_or_else(|_| { - std::fs::write(CONFIG_PATH, DEFAULT_VERSIONS).expect("Failed to create versions file"); - DEFAULT_VERSIONS.to_string() - }); + let data = match fs::read_to_string(CONFIG_PATH).await { + Ok(data) => data, + Err(_) => { + fs::write(CONFIG_PATH, DEFAULT_VERSIONS) + .await + .expect("Failed to create versions file"); + DEFAULT_VERSIONS.to_string() + } + }; - serde_json::from_str(&data).unwrap_or_else(|e| panic!("Failed to parse versions.json: {e}")) - }) + match serde_json::from_str(&data) { + Ok(data) => data, + Err(_) => { + tracing::error!("malformed versions.json. replacing it with default one"); + let _ = fs::write(CONFIG_PATH, DEFAULT_VERSIONS).await; + serde_json::from_str(DEFAULT_VERSIONS).unwrap() + } + } + } + + pub async fn fetch_hotfix(version: &str, dispatch_seed: &str) -> Result { + let (region, branch, _, _) = Self::parse_version_string(version)?; + + let host = match (region.as_str(), branch.as_str()) { + ("OS", "BETA") => Self::OSBETA_HOST, + ("OS", "PROD") => Self::OSPROD_HOST, + ("CN", "BETA") => Self::CNBETA_HOST, + ("CN", "PROD") => Self::CNPROD_HOST, + // TODO: Support more host, or use query_dispatch result to determine the urls + _ => Self::OSBETA_HOST, + }; + + let url = format!( + "https://{}/{host}/query_gateway?version={}&platform_type=1&language_type=3&dispatch_seed={}&channel_id=1&sub_channel_id=1&is_need_url=1", + Self::PROXY_HOST, + version, + dispatch_seed + ); + + tracing::info!("fetching hotfix: {url}"); + + let res = reqwest::get(url).await?.text().await?; + + tracing::info!("raw gateway response: {}", res); + + let bytes = rbase64::decode(&res)?; + let decoded = Gateserver::decode(bytes.as_slice())?; + + if decoded.retcode != 0 || res.is_empty() { + return Err(anyhow::format_err!( + "gateway result code: {} message: {}", + decoded.retcode, + decoded.msg + )); + } + + tracing::info!("{:#?}", decoded); + + Ok(VersionConfig { + asset_bundle_url: decoded.asset_bundle_url, + ex_resource_url: decoded.ex_resource_url, + lua_url: decoded.lua_url, + ifix_url: decoded.ifix_url, + }) + } + + fn parse_version_string(s: &str) -> Result<(String, String, String, String)> { + const BRANCHES: [&str; 7] = ["PREbeta", "BETA", "PROD", "DEV", "PRE", "GM", "CECREATION"]; + const OS: [&str; 3] = ["Android", "Win", "iOS"]; + + let region = s + .get(0..2) + .ok_or_else(|| anyhow!("version parse failed. reason: invalid region url: {s}"))?; + let after_region = s.get(2..).unwrap_or(s); + + let branch = BRANCHES + .iter() + .find(|&&b| after_region.starts_with(b)) + .map(|&b| b.to_string()) + .ok_or_else(|| anyhow!("version parse failed. reason: invalid branch url: {s}"))?; + + let branch_len = branch.len(); + let after_branch = after_region.get(branch_len..).unwrap_or(after_region); + + let os = OS + .iter() + .find(|&&o| after_branch.starts_with(o)) + .map(|&o| o.to_string()) + .ok_or_else(|| anyhow!("version parse failed. reason: invalid os url: {s}"))?; + + let os_len = os.len(); + let version = after_branch + .get(os_len..) + .ok_or_else(|| anyhow!("version parse failed. reason: invalid version url: {s}"))?; + + Ok((region.to_string(), branch, os, version.to_string())) + } } diff --git a/sdkserver/src/lib.rs b/sdkserver/src/lib.rs index 1fe687a..0dd39b4 100644 --- a/sdkserver/src/lib.rs +++ b/sdkserver/src/lib.rs @@ -1,9 +1,15 @@ +use std::collections::HashMap; +use std::sync::Arc; + use anyhow::Result; use axum::Router; use axum::http::Method; use axum::http::header::CONTENT_TYPE; use axum::routing::{get, post}; +use config::version_config::VersionConfig; use services::{auth, dispatch, errors, sr_tools}; +use tokio::fs; +use tokio::sync::RwLock; use tower_http::cors::{Any, CorsLayer}; use tracing::Level; @@ -12,9 +18,51 @@ mod services; const PORT: u16 = 21000; +#[derive(Clone)] +struct AppState { + hotfix_map: HashMap, +} + +impl AppState { + async fn get_or_insert_hotfix(&mut self, version: &str, dispatch_seed: &str) -> &VersionConfig { + if self.hotfix_map.contains_key(version) { + return &self.hotfix_map[version]; + } + + tracing::info!( + "trying to fetch hotfix for version {version} with dispatch seed {dispatch_seed}" + ); + + let hotfix = match VersionConfig::fetch_hotfix(version, dispatch_seed).await { + Ok(hotfix) => hotfix, + Err(err) => { + tracing::error!("failed to fetch hotfix. reason: {err}"); + VersionConfig::default() + } + }; + + self.hotfix_map.insert(version.to_string(), hotfix); + + if let Ok(serialized) = serde_json::to_string_pretty(&self.hotfix_map) { + let _ = fs::write("versions.json", serialized).await; + } + + &self.hotfix_map[version] + } +} + pub async fn start_sdkserver() -> Result<()> { let span = tracing::span!(Level::DEBUG, "main"); let _ = span.enter(); + let hotfix_map = VersionConfig::load_hotfix().await; + + tracing::info!( + "loaded {} hotfix versions. supported versions: {:?}", + hotfix_map.len(), + hotfix_map.keys() + ); + + let state = Arc::new(RwLock::new(AppState { hotfix_map })); let router = Router::new() .route( @@ -48,6 +96,7 @@ pub async fn start_sdkserver() -> Result<()> { .allow_methods([Method::GET, Method::POST, Method::PATCH, Method::DELETE]) .allow_headers([CONTENT_TYPE]), ) + .with_state(state) .fallback(errors::not_found); let addr = format!("0.0.0.0:{PORT}"); diff --git a/sdkserver/src/services/dispatch.rs b/sdkserver/src/services/dispatch.rs index a4e4dd1..f086443 100644 --- a/sdkserver/src/services/dispatch.rs +++ b/sdkserver/src/services/dispatch.rs @@ -1,8 +1,12 @@ -use crate::config::version_config; -use axum::extract::Query; +use std::sync::Arc; + +use crate::AppState; +use axum::extract::{Query, State}; use prost::Message; use proto::{DispatchRegionData, Gateserver, RegionEntry}; use serde::Deserialize; +use tokio::sync::RwLock; +use tracing::instrument; pub const QUERY_DISPATCH_ENDPOINT: &str = "/query_dispatch"; pub const QUERY_GATEWAY_ENDPOINT: &str = "/query_gateway"; @@ -30,37 +34,38 @@ pub async fn query_dispatch() -> String { #[derive(Deserialize, Debug)] pub struct QueryGatewayParameters { pub version: String, + pub dispatch_seed: String, } -#[tracing::instrument] -pub async fn query_gateway(parameters: Query) -> String { - let rsp = if let Some(config) = version_config::instance().get(¶meters.version) { - Gateserver { - retcode: 0, - ip: String::from("127.0.0.1"), - port: 23301, - asset_bundle_url: config.asset_bundle_url.clone(), - ex_resource_url: config.ex_resource_url.clone(), - lua_url: config.lua_url.clone(), - ifix_version: String::from("0"), - enable_design_data_version_update: true, - enable_version_update: true, - enable_upload_battle_log: true, - network_diagnostic: true, - close_redeem_code: true, - enable_android_middle_package: true, - enable_watermark: true, - event_tracking_open: true, - enable_cdn_ipv6: 1, - enable_save_replay_file: true, - ..Default::default() - } - } else { - Gateserver { - retcode: 9, - login_white_msg: format!("forbidden version: {} or invalid bind", parameters.version), - ..Default::default() - } +#[instrument(skip(state))] +pub async fn query_gateway( + State(state): State>>, + parameters: Query, +) -> String { + let mut lock = state.write().await; + let config = lock + .get_or_insert_hotfix(¶meters.version, ¶meters.dispatch_seed) + .await; + + let rsp = Gateserver { + retcode: 0, + ip: String::from("127.0.0.1"), + port: 23301, + asset_bundle_url: config.asset_bundle_url.clone(), + ex_resource_url: config.ex_resource_url.clone(), + lua_url: config.lua_url.clone(), + ifix_version: String::from("0"), + enable_design_data_version_update: true, + enable_version_update: true, + enable_upload_battle_log: true, + network_diagnostic: true, + close_redeem_code: true, + enable_android_middle_package: true, + enable_watermark: true, + event_tracking_open: true, + enable_cdn_ipv6: 1, + enable_save_replay_file: true, + ..Default::default() }; let mut buff = Vec::new();