diff --git a/src/download_file.rs b/src/download_file.rs index 741b6fe5..fb338694 100644 --- a/src/download_file.rs +++ b/src/download_file.rs @@ -80,9 +80,9 @@ async fn download_file( pb.set_position(downloaded); write.await?; } - pb.finish_and_clear(); file.flush().await?; file.sync_all().await?; + pb.finish_and_clear(); Ok(DownloadedFile { url: url.into(), diff --git a/src/github/utils.rs b/src/github/utils.rs index ed97cef7..00a1452f 100644 --- a/src/github/utils.rs +++ b/src/github/utils.rs @@ -250,68 +250,74 @@ mod tests { "Package.Identifier.installer.yaml", "Package.Identifier", None, - ManifestType::Installer, - true + ManifestType::Installer )] #[case( - "Package.Identifier.yaml", + "Package.Identifier.locale.en-US.yaml", "Package.Identifier", - None, - ManifestType::Installer, - false + Some("en-US"), + ManifestType::DefaultLocale )] #[case( - "Package.Identifier.locale.en-US.yaml", + "Package.Identifier.locale.zh-CN.yaml", "Package.Identifier", Some("en-US"), - ManifestType::DefaultLocale, - true + ManifestType::Locale )] #[case( - "Package.Identifier.locale.en-US.yaml", + "Package.Identifier.yaml", "Package.Identifier", - Some("zh-CN"), - ManifestType::DefaultLocale, - false + None, + ManifestType::Version )] + fn valid_manifest_files( + #[case] file_name: &str, + #[case] identifier: &str, + #[case] default_locale: Option<&str>, + #[case] manifest_type: ManifestType, + ) { + let identifier = PackageIdentifier::parse(identifier).unwrap(); + let default_locale = default_locale.and_then(|locale| LanguageTag::from_str(locale).ok()); + assert!(is_manifest_file( + file_name, + &identifier, + default_locale.as_ref(), + &manifest_type + )) + } + + #[rstest] #[case( - "Package.Identifier.locale.zh-CN.yaml", + "Package.Identifier.yaml", "Package.Identifier", - Some("en-US"), - ManifestType::Locale, - true + None, + ManifestType::Installer )] #[case( "Package.Identifier.locale.en-US.yaml", "Package.Identifier", - Some("en-US"), - ManifestType::Locale, - false + Some("zh-CN"), + ManifestType::DefaultLocale )] #[case( - "Package.Identifier.yaml", + "Package.Identifier.locale.en-US.yaml", "Package.Identifier", - None, - ManifestType::Version, - true + Some("en-US"), + ManifestType::Locale )] - fn is_manifest_files( + fn invalid_manifest_files( #[case] file_name: &str, #[case] identifier: &str, #[case] default_locale: Option<&str>, #[case] manifest_type: ManifestType, - #[case] expected: bool, ) { let identifier = PackageIdentifier::parse(identifier).unwrap(); let default_locale = default_locale.and_then(|locale| LanguageTag::from_str(locale).ok()); - assert_eq!( - is_manifest_file( - file_name, - &identifier, - default_locale.as_ref(), - &manifest_type - ), - expected - ) + assert!(!is_manifest_file( + file_name, + &identifier, + default_locale.as_ref(), + &manifest_type + )) } } diff --git a/src/types/minimum_os_version.rs b/src/types/minimum_os_version.rs index f42e4c7a..b5f75dcd 100644 --- a/src/types/minimum_os_version.rs +++ b/src/types/minimum_os_version.rs @@ -1,37 +1,74 @@ -use color_eyre::eyre::Error; -use derive_more::{Display, FromStr}; -use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; -use versions::Version; +use color_eyre::eyre::OptionExt; +use derive_more::Display; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::str::FromStr; -#[derive(Clone, Debug, Default, Display, Eq, FromStr, Hash, Ord, PartialEq, PartialOrd)] -pub struct MinimumOSVersion(Version); +#[derive( + SerializeDisplay, + DeserializeFromStr, + Copy, + Clone, + Debug, + Default, + Display, + Eq, + PartialEq, + Hash, + Ord, + PartialOrd, +)] +#[display("{_0}.{_1}.{_2}.{_3}")] +pub struct MinimumOSVersion(u16, u16, u16, u16); -impl MinimumOSVersion { - pub fn new(input: &str) -> color_eyre::Result { - Ok(Self(Version::from_str(input).map_err(Error::msg)?)) - } +impl FromStr for MinimumOSVersion { + type Err = color_eyre::Report; - pub fn removable() -> Self { - Self::new("10.0.0.0").unwrap() + fn from_str(s: &str) -> Result { + let mut parts = s.splitn(Self::MAX_PARTS as usize, Self::SEPARATOR); + let major = parts + .next() + .ok_or_eyre("No major version")? + .parse::()?; + let minor = parts.next().map_or(Ok(0), u16::from_str)?; + let patch = parts.next().map_or(Ok(0), u16::from_str)?; + let build = parts.next().map_or(Ok(0), u16::from_str)?; + Ok(Self(major, minor, patch, build)) } } -impl Serialize for MinimumOSVersion { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(&self.0.to_string()) +impl MinimumOSVersion { + const MAX_PARTS: u8 = 4; + const SEPARATOR: char = '.'; + + pub const fn removable() -> Self { + Self(10, 0, 0, 0) } } -impl<'de> Deserialize<'de> for MinimumOSVersion { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - String::deserialize(deserializer)? - .parse() - .map_err(de::Error::custom) +#[cfg(test)] +mod tests { + use crate::types::minimum_os_version::MinimumOSVersion; + use rstest::rstest; + use std::str::FromStr; + + #[rstest] + #[case("10.0.17763.0", MinimumOSVersion(10, 0, 17763, 0))] + #[case("11", MinimumOSVersion(11, 0, 0, 0))] + #[case("10.1", MinimumOSVersion(10, 1, 0, 0))] + #[case("0", MinimumOSVersion(0, 0, 0, 0))] + #[case( + "65535.65535.65535.65535", + MinimumOSVersion(u16::MAX, u16::MAX, u16::MAX, u16::MAX) + )] + fn from_str(#[case] minimum_os_version: &str, #[case] expected: MinimumOSVersion) { + assert_eq!( + MinimumOSVersion::from_str(minimum_os_version).unwrap(), + expected + ) + } + + #[test] + fn display() { + assert_eq!(MinimumOSVersion(1, 2, 3, 4).to_string(), "1.2.3.4") } } diff --git a/src/types/package_version.rs b/src/types/package_version.rs index cbfac04f..fababa16 100644 --- a/src/types/package_version.rs +++ b/src/types/package_version.rs @@ -1,43 +1,33 @@ use crate::prompts::prompt::RequiredPrompt; -use color_eyre::eyre::{Error, Result}; use derive_more::{Deref, Display}; -use serde::{de, Deserialize, Deserializer, Serialize, Serializer}; +use serde_with::{DeserializeFromStr, SerializeDisplay}; use std::str::FromStr; use versions::Versioning; -#[derive(Clone, Default, Deref, Display, Eq, Ord, PartialEq, PartialOrd)] +#[derive( + SerializeDisplay, + DeserializeFromStr, + Clone, + Default, + Deref, + Display, + Eq, + Ord, + PartialEq, + PartialOrd, +)] pub struct PackageVersion(Versioning); impl PackageVersion { - pub fn new(input: &str) -> Result { - Ok(Self(Versioning::from_str(input).map_err(Error::msg)?)) - } -} - -impl Serialize for PackageVersion { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: Serializer, - { - serializer.serialize_str(&self.0.to_string()) - } -} - -impl<'de> Deserialize<'de> for PackageVersion { - fn deserialize(deserializer: D) -> std::result::Result - where - D: Deserializer<'de>, - { - String::deserialize(deserializer)? - .parse() - .map_err(de::Error::custom) + pub fn new(input: &str) -> Result { + Ok(Self(Versioning::from_str(input)?)) } } impl FromStr for PackageVersion { - type Err = Error; + type Err = versions::Error; - fn from_str(input: &str) -> Result { + fn from_str(input: &str) -> Result { Self::new(input) } }