Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@ http = "1"
reqwest = { version = "0.12", default-features = false }
once_cell = "1"
tokio = { version = "1", features = ["macros"] }
flate2 = "1.1.2"
brotli = "8.0.2"

[build-dependencies]
tauri-plugin = { version = "2", features = ["build"] }

[dev-dependencies]
httpmock = "0.6"

[features]
default = [
"rustls-tls",
Expand Down
117 changes: 117 additions & 0 deletions src/command_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#[cfg(test)]
mod tests {
use crate::commands::{get_response, build_request, RequestConfig};
use url::Url;
use tokio::sync::oneshot;
use flate2::{write::GzEncoder, write::DeflateEncoder, Compression};
use brotli::CompressorWriter;
use std::io::Write;

fn encode_gzip(data: &[u8]) -> Vec<u8> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(data).unwrap();
encoder.finish().unwrap()
}

fn encode_deflate(data: &[u8]) -> Vec<u8> {
let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
encoder.write_all(data).unwrap();
encoder.finish().unwrap()
}

fn encode_brotli(data: &[u8]) -> Vec<u8> {
let mut encoder = CompressorWriter::new(Vec::new(), 4096, 5, 22);
encoder.write_all(data).unwrap();
encoder.into_inner()
}

#[tokio::test]
async fn test_get_response_gzip() {
use httpmock::MockServer;
let server = MockServer::start_async().await;
let data = b"hello gzip";
let encoded = encode_gzip(data);
let mock = server.mock_async(|when, then| {
when.method("GET").path("/test_gzip");
then.status(200)
.header("content-encoding", "gzip")
.body(encoded.clone());
}).await;
let url = Url::parse(&format!("{}test_gzip", server.url("/"))).unwrap();
let request_config = RequestConfig::new(
1,
"GET".to_string(),
url,
vec![],
None,
None,
None,
None,
);
let request = build_request(request_config).unwrap();
let (_tx, rx) = oneshot::channel();
let response = get_response(request, rx).await.unwrap();
assert_eq!(response.body().as_ref().unwrap(), data);
mock.assert_async().await;
}

#[tokio::test]
async fn test_get_response_deflate() {
use httpmock::MockServer;
let server = MockServer::start_async().await;
let data = b"hello deflate";
let encoded = encode_deflate(data);
let mock = server.mock_async(|when, then| {
when.method("GET").path("/test_deflate");
then.status(200)
.header("content-encoding", "deflate")
.body(encoded.clone());
}).await;
let url = Url::parse(&format!("{}test_deflate", server.url("/"))).unwrap();
let request_config = RequestConfig::new(
2,
"GET".to_string(),
url,
vec![],
None,
None,
None,
None,
);
let request = build_request(request_config).unwrap();
let (_tx, rx) = oneshot::channel();
let response = get_response(request, rx).await.unwrap();
assert_eq!(response.body().as_ref().unwrap(), data);
mock.assert_async().await;
}

#[tokio::test]
async fn test_get_response_brotli() {
use httpmock::MockServer;
let server = MockServer::start_async().await;
let data = b"hello brotli";
let encoded = encode_brotli(data);
let mock = server.mock_async(|when, then| {
when.method("GET").path("/test_brotli");
then.status(200)
.header("content-encoding", "br")
.body(encoded.clone());
}).await;
let url = Url::parse(&format!("{}test_brotli", server.url("/"))).unwrap();
let request_config = RequestConfig::new(
3,
"GET".to_string(),
url,
vec![],
None,
None,
None,
None,
);
let request = build_request(request_config).unwrap();
let (_tx, rx) = oneshot::channel();
let response = get_response(request, rx).await.unwrap();
assert_eq!(response.body().as_ref().unwrap(), data);
mock.assert_async().await;
}
}
78 changes: 73 additions & 5 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ pub struct RequestConfig {
proxy: Option<Proxy>,
}

#[cfg(test)]
impl RequestConfig {
pub(crate) fn new(
request_id: u64,
method: String,
url: url::Url,
headers: Vec<(String, String)>,
data: Option<Vec<u8>>,
connect_timeout: Option<u64>,
max_redirections: Option<usize>,
proxy: Option<Proxy>,
) -> Self {
Self {
request_id,
method,
url,
headers,
data,
connect_timeout,
max_redirections,
proxy,
}
}
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FetchResponse {
Expand All @@ -36,6 +61,13 @@ pub struct FetchResponse {
body: Option<Vec<u8>>,
}

#[cfg(test)]
impl FetchResponse {
pub(crate) fn body(&self) -> &Option<Vec<u8>> {
&self.body
}
}

use once_cell::sync::Lazy;
use tokio::sync::oneshot;
type RequestPool = Arc<std::sync::Mutex<HashMap<u64, oneshot::Sender<()>>>>;
Expand Down Expand Up @@ -141,18 +173,54 @@ pub async fn get_response(
let status = res.status();
let url = res.url().to_string();
let mut headers = Vec::new();
let mut content_encoding = None;
for (key, val) in res.headers().iter() {
headers.push((
key.as_str().into(),
String::from_utf8(val.as_bytes().to_vec())?,
));
let key_str: String = key.as_str().into();
let val_str = String::from_utf8(val.as_bytes().to_vec())?;
if key_str.eq_ignore_ascii_case("content-encoding") {
content_encoding = Some(val_str.to_lowercase());
}
headers.push((key_str, val_str));
}
let bytes = res.bytes().await?;
let body = if let Some(enc) = content_encoding {
match enc.as_str() {
"gzip" => {
use flate2::read::GzDecoder;
use std::io::Read;
let mut d = GzDecoder::new(&bytes[..]);
let mut decompressed = Vec::new();
d.read_to_end(&mut decompressed)?;
decompressed
},
"deflate" => {
use flate2::read::DeflateDecoder;
use std::io::Read;
let mut d = DeflateDecoder::new(&bytes[..]);
let mut decompressed = Vec::new();
d.read_to_end(&mut decompressed)?;
decompressed
},
"br" => {
use brotli::Decompressor;
use std::io::Read;
let mut d = Decompressor::new(&bytes[..], 4096);
let mut decompressed = Vec::new();
d.read_to_end(&mut decompressed)?;
decompressed
},
"identity" => bytes.to_vec(),
_ => bytes.to_vec(),
}
} else {
bytes.to_vec()
};
return Ok(FetchResponse {
status: status.as_u16(),
status_text: status.canonical_reason().unwrap_or_default().to_string(),
headers,
url,
body: Some(res.bytes().await?.to_vec()),
body: Some(body),
});
}
Err(err) => return Err(Error::Network(err)),
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use tauri::{
pub use error::{Error, Result};
mod commands;
mod error;
#[cfg(test)]
mod command_test;

pub fn init<R: Runtime>() -> TauriPlugin<R> {
Builder::<R>::new("cors-fetch")
Expand Down