From 173553f8c5f1c88033ddf61745839eca0e8a776e Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Sun, 19 Apr 2026 14:41:40 +0100 Subject: [PATCH 01/11] Support HF downloading models (#16) * Add HF downloader support Signed-off-by: kerthcet * add bars Signed-off-by: kerthcet * fix color Signed-off-by: kerthcet * fix color Signed-off-by: kerthcet * add download successfully message Signed-off-by: kerthcet * change the color Signed-off-by: kerthcet * change the rending shape Signed-off-by: kerthcet --------- Signed-off-by: kerthcet --- Cargo.lock | 483 +++++++++++++++++++++++++++++++++- Cargo.toml | 2 + Makefile | 5 +- README.md | 48 +++- src/cli/commands.rs | 26 +- src/downloader/downloader.rs | 19 +- src/downloader/huggingface.rs | 221 ++++++++++++++++ src/downloader/mod.rs | 3 +- src/main.rs | 3 +- 9 files changed, 781 insertions(+), 29 deletions(-) create mode 100644 src/downloader/huggingface.rs diff --git a/Cargo.lock b/Cargo.lock index 522f786..99f9700 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,6 +109,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bitflags" version = "2.8.0" @@ -121,6 +127,12 @@ version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.10.0" @@ -188,6 +200,16 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "colored" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys 0.59.0", +] + [[package]] name = "console" version = "0.15.10" @@ -201,6 +223,47 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "console" +version = "0.16.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d64e8af5551369d19cf50138de61f1c42074ab970f74e99be916646777f8fc87" +dependencies = [ + "encode_unicode", + "libc", + "unicode-width 0.2.0", + "windows-sys 0.61.2", +] + +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "cookie_store" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b2c103cf610ec6cae3da84a766285b42fd16aad564758459e6ecf128c75206" +dependencies = [ + "cookie", + "document-features", + "idna", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "time", + "url", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -217,6 +280,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "csv" version = "1.3.1" @@ -238,6 +310,25 @@ dependencies = [ "memchr", ] +[[package]] +name = "der" +version = "0.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "71fd89660b2dc699704064e59e9dba0147b903e85319429e131620d022be411b" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + +[[package]] +name = "deranged" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" +dependencies = [ + "powerfmt", +] + [[package]] name = "dirs" version = "6.0.0" @@ -266,7 +357,7 @@ dependencies = [ "libc", "option-ext", "redox_users 0.5.0", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -291,6 +382,15 @@ dependencies = [ "syn", ] +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + [[package]] name = "encode_unicode" version = "1.0.0" @@ -351,6 +451,16 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -381,6 +491,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -388,6 +513,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -396,6 +522,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -414,10 +568,16 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -486,6 +646,36 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "hf-hub" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"aef3982638978efa195ff11b305f51f1f22f4f0a6cabee7af79b383ebee6a213" +dependencies = [ + "dirs", + "futures", + "http", + "indicatif 0.18.4", + "libc", + "log", + "native-tls", + "num_cpus", + "rand", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.11", + "tokio", + "ureq", + "windows-sys 0.61.2", +] + [[package]] name = "http" version = "1.2.0" @@ -759,13 +949,26 @@ version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ - "console", + "console 0.15.10", "number_prefix", "portable-atomic", "unicode-width 0.2.0", "web-time", ] +[[package]] +name = "indicatif" +version = "0.18.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" +dependencies = [ + "console 0.16.3", + "portable-atomic", + "unicode-width 0.2.0", + "unit-prefix", + "web-time", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -778,7 +981,7 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" dependencies = [ - "hermit-abi", + "hermit-abi 0.4.0", "libc", "windows-sys 0.59.0", ] @@ -839,6 +1042,12 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + [[package]] name = "lock_api" version = "0.4.12" @@ -874,6 +1083,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -904,6 +1114,22 @@ dependencies = [ "tempfile", ] 
+[[package]] +name = "num-conv" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" + +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi 0.5.2", + "libc", +] + [[package]] name = "number_prefix" version = "0.4.0" @@ -998,6 +1224,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "pem-rfc7468" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6305423e0e7738146434843d1694d621cce767262b2a86910beab705e4493d9" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -1024,9 +1259,24 @@ checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] [[package]] name = "prettytable-rs" @@ -1056,9 +1306,11 @@ name = "puma" version = "0.0.1" dependencies = [ "clap", + "colored", "dirs", "env_logger", - "indicatif", + "hf-hub", + "indicatif 0.17.11", "log", "prettytable-rs", "reqwest", @@ -1076,6 +1328,35 @@ dependencies = [ "proc-macro2", ] +[[package]] 
+name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.1", +] + [[package]] name = "redox_syscall" version = "0.5.9" @@ -1171,11 +1452,13 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-util", "tower", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "windows-registry", ] @@ -1219,7 +1502,9 @@ version = "0.23.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" dependencies = [ + "log", "once_cell", + "ring", "rustls-pki-types", "rustls-webpki", "subtle", @@ -1304,18 +1589,28 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.218" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.218" +version = "1.0.228" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -1361,6 +1656,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + [[package]] name = "slab" version = "0.4.9" @@ -1386,6 +1687,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -1521,6 +1833,37 @@ dependencies = [ "syn", ] +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tinystr" version = "0.7.6" @@ -1663,12 +2006,54 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] 
+name = "unit-prefix" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81e544489bf3d8ef66c953931f56617f423cd4b5494be343d9b9d3dda037b9a3" + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64", + "cookie_store", + "der", + "flate2", + "log", + "native-tls", + "percent-encoding", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "ureq-proto", + "utf8-zero", + "webpki-root-certs", + "webpki-roots", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.4" @@ -1686,6 +2071,12 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -1704,6 +2095,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "want" version 
= "0.3.1" @@ -1799,6 +2196,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -1819,6 +2229,24 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31141ce3fc3e300ae89b78c0dd67f9708061d1d2eda54b8209346fd6be9a92c" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1841,6 +2269,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + [[package]] name = "windows-registry" version = "0.2.0" @@ -1889,6 +2323,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -1998,6 +2441,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zerocopy" +version = "0.8.27" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zerofrom" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index ef37a03..f79acd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,5 @@ env_logger = "0.11.6" log = "0.4.26" indicatif = "0.17.11" dirs = "6.0.0" +hf-hub = { version = "0.5.0", features = ["tokio"] } +colored = "2.1" diff --git a/Makefile b/Makefile index 36f1dd8..4fad7b1 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,5 @@ build: - cargo build && cp target/debug/puma ./puma \ No newline at end of file + cargo build && cp target/debug/puma ./puma + +test: + cargo test diff --git a/README.md b/README.md index c9317d9..3adf683 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,49 @@ # PUMA -**PUMA** aims to be a lightweight, high-performance inference engine for local AI. Play for fun. +**PUMA** aims to be a lightweight, high-performance inference engine for local AI. *Play for fun.* -## How to Run +## Features -### Build +- πŸš€ **Model Management** - Download and manage AI models from multiple providers + +## Quick Start + +### Installation + +```bash +make build +``` -Run `make build` to build the **puma** binary. 
+## Commands + +| Command | Description | +|---------|-------------| +| `pull` | Download a model from a provider | +| `ls` | List local models | +| `ps` | List running models | +| `run` | Create and run a model | +| `stop` | Stop a running model | +| `rm` | Remove a model | +| `info` | Display system-wide information | +| `inspect` | Return detailed information about a model | +| `version` | Show PUMA version | +| `help` | Show help information | + +## Development + +### Build -### Run +```bash +make compile +``` -Run `./puma help` to see all available commands. +### Test -For example, you can run `./puma version` to see the binary version. +```bash +make test +``` -## Supported Backends +### Supported Providers -Use [llama.cpp](https://github.com/ggerganov/llama.cpp) as the default backend for quick prototyping, will implement our own backend in the future. +- βœ… **Hugging Face** - Full support with custom cache directories +- 🚧 **ModelScope** - Coming soon diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 7e9a6a7..6e86627 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -1,5 +1,9 @@ use clap::{Parser, Subcommand}; use prettytable::{format, row, Table}; +use std::path::PathBuf; + +use crate::downloader::downloader::Downloader; +use crate::downloader::huggingface::HuggingFaceDownloader; #[derive(Parser)] #[command(name = "PUMA")] @@ -33,10 +37,18 @@ enum Commands { #[derive(Parser)] struct PullArgs { - #[arg(long, value_name = "model name")] + #[arg(short = 'm', long, value_name = "model name")] model: String, - #[arg(long, value_name = "model provider", value_enum)] + #[arg( + short = 'p', + long, + value_name = "model provider", + value_enum, + default_value = "huggingface" + )] provider: Provider, + #[arg(long, value_name = "cache directory")] + cache_dir: Option, } #[derive(Debug, Clone, clap::ValueEnum)] @@ -85,7 +97,15 @@ pub async fn run(cli: Cli) { Commands::PULL(args) => match args.provider { Provider::Huggingface => { - 
println!("Downloading model from Huggingface..."); + let downloader = HuggingFaceDownloader::new(); + let cache_dir = args.cache_dir.unwrap_or_else(|| PathBuf::new()); + match downloader.download_model(&args.model, &cache_dir).await { + Ok(_) => {} + Err(e) => { + eprintln!("Error downloading model: {}", e); + std::process::exit(1); + } + } } Provider::Modelscope => { println!("Downloading model from Modelscope..."); diff --git a/src/downloader/downloader.rs b/src/downloader/downloader.rs index 5320258..4b0e59e 100644 --- a/src/downloader/downloader.rs +++ b/src/downloader/downloader.rs @@ -1,9 +1,13 @@ use core::fmt; +use std::path::PathBuf; #[derive(Debug)] pub enum DownloadError { - RequestError(String), - ParseError(String), + NetworkError(String), + AuthError(String), + ModelNotFound(String), + IoError(String), + ApiError(String), } impl std::error::Error for DownloadError {} @@ -11,8 +15,15 @@ impl std::error::Error for DownloadError {} impl fmt::Display for DownloadError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - DownloadError::RequestError(e) => write!(f, "RequestError: {}", e), - DownloadError::ParseError(e) => write!(f, "ParseError: {}", e), + DownloadError::NetworkError(e) => write!(f, "Network error: {}", e), + DownloadError::AuthError(e) => write!(f, "Authentication error: {}", e), + DownloadError::ModelNotFound(e) => write!(f, "Model not found: {}", e), + DownloadError::IoError(e) => write!(f, "IO error: {}", e), + DownloadError::ApiError(e) => write!(f, "API error: {}", e), } } } + +pub trait Downloader { + async fn download_model(&self, name: &str, cache_dir: &PathBuf) -> Result<(), DownloadError>; +} diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs new file mode 100644 index 0000000..d046b6b --- /dev/null +++ b/src/downloader/huggingface.rs @@ -0,0 +1,221 @@ +use colored::Colorize; +use log::{debug, info}; +use std::path::PathBuf; +use std::sync::Arc; + +use hf_hub::api::tokio::{ApiBuilder, 
Progress}; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; + +use crate::downloader::downloader::{DownloadError, Downloader}; + +#[derive(Clone)] +struct FileProgressBar { + pb: ProgressBar, +} + +impl Progress for FileProgressBar { + async fn init(&mut self, size: usize, _filename: &str) { + self.pb.set_length(size as u64); + self.pb.reset(); + self.pb.tick(); // Force render with correct size + } + + async fn update(&mut self, size: usize) { + self.pb.inc(size as u64); + } + + async fn finish(&mut self) {} +} + +pub struct HuggingFaceDownloader; + +impl HuggingFaceDownloader { + pub fn new() -> Self { + Self + } +} + +impl Default for HuggingFaceDownloader { + fn default() -> Self { + Self::new() + } +} + +impl Downloader for HuggingFaceDownloader { + async fn download_model(&self, name: &str, cache_dir: &PathBuf) -> Result<(), DownloadError> { + let start_time = std::time::Instant::now(); + + info!("Downloading model {} from Hugging Face...", name); + + // Build API without default progress bars (we have our own implementation) + let api = if cache_dir.as_os_str().is_empty() { + // Use default HF cache + ApiBuilder::new().build().map_err(|e| { + DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e)) + })? + } else { + // Use custom cache directory + ApiBuilder::new() + .with_cache_dir(cache_dir.clone()) + .build() + .map_err(|e| { + DownloadError::ApiError(format!( + "Failed to initialize Hugging Face API with custom cache: {}", + e + )) + })? 
+ }; + + // Download the entire model repository using snapshot download + let repo = api.model(name.to_string()); + + // Get model info to list all files + let model_info = repo.info().await.map_err(|e| { + let err_str = e.to_string(); + if err_str.contains("404") || err_str.contains("not found") { + DownloadError::ModelNotFound(format!("Model '{}' not found", name)) + } else if err_str.contains("401") || err_str.contains("403") { + DownloadError::AuthError(format!("Authentication failed: {}", e)) + } else if err_str.contains("network") || err_str.contains("connection") { + DownloadError::NetworkError(format!("Network error: {}", e)) + } else { + DownloadError::ApiError(format!("Failed to fetch model info: {}", e)) + } + })?; + + debug!("Model info for {}: {:?}", name, model_info); + + // Create multi-progress for parallel downloads + let multi_progress = Arc::new(MultiProgress::new()); + + // Progress bar style with block characters (chart-like, not #) + let style = ProgressStyle::default_bar() + .template("{msg:<30} [{elapsed_precise}] {bar:60.white} {bytes}/{total_bytes}") + .unwrap() + .progress_chars("▇▆▅▄▃▂▁ "); + + // Download all files in parallel + let mut tasks = Vec::new(); + + for sibling in model_info.siblings { + let api_clone = api.clone(); + let model_name = name.to_string(); + let filename = sibling.rfilename.clone(); + + let pb = multi_progress.add(ProgressBar::hidden()); + pb.set_style(style.clone()); + pb.set_message(filename.clone()); + + let task = tokio::spawn(async move { + debug!("Downloading: {}", filename); + + let repo = api_clone.model(model_name); + let progress = FileProgressBar { pb: pb.clone() }; + + let result = repo.download_with_progress(&filename, progress).await; + + match &result { + Ok(_) => { + pb.finish(); + } + Err(_) => { + pb.abandon(); + } + } + + result.map_err(|e| { + DownloadError::NetworkError(format!("Failed to download {}: {}", filename, e)) + }) + }); + + tasks.push(task); + } + + // Wait for all downloads to 
complete + for task in tasks { + task.await + .map_err(|e| DownloadError::ApiError(format!("Task join error: {}", e)))??; + } + + let elapsed_time = start_time.elapsed(); + + println!( + "\n{} {} {} {} {:.2?}", + "βœ“".green().bold(), + "Successfully downloaded model".bright_white(), + name.cyan().bold(), + "in".bright_white(), + elapsed_time + ); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_download_model_invalid() { + let downloader = HuggingFaceDownloader::new(); + let result = downloader + .download_model("invalid-model-that-does-not-exist-12345", &PathBuf::new()) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_download_real_tiny_model() { + let downloader = HuggingFaceDownloader::new(); + // Use HF's official tiny test model (only a few KB) + let result = downloader + .download_model("InftyAI/tiny-random-gpt2", &PathBuf::new()) + .await; + assert!( + result.is_ok(), + "Failed to download tiny model: {:?}", + result + ); + + // Cleanup: remove the downloaded files from the default HF cache (~/.cache/huggingface/hub) + if let Some(home_dir) = dirs::home_dir() { + let cache_dir = home_dir + .join(".cache") + .join("huggingface") + .join("hub") + .join("models--InftyAI--tiny-random-gpt2"); + + if cache_dir.exists() { + let _ = std::fs::remove_dir_all(&cache_dir); + } + } + } + + #[tokio::test] + async fn test_download_with_custom_cache() { + use std::env; + use std::fs; + + let downloader = HuggingFaceDownloader::new(); + let temp_dir = env::temp_dir().join("puma_test_cache"); + + print!("Using temporary cache directory: {:?}\n", temp_dir); + + // Create the directory first + fs::create_dir_all(&temp_dir).unwrap(); + + let result = downloader + .download_model("InftyAI/tiny-random-gpt2", &temp_dir) + .await; + + assert!( + result.is_ok(), + "Failed to download with custom cache: {:?}", + result + ); + + // Cleanup + let _ = std::fs::remove_dir_all(&temp_dir); + } +} diff --git 
a/src/downloader/mod.rs b/src/downloader/mod.rs index a48aa6c..7af4570 100644 --- a/src/downloader/mod.rs +++ b/src/downloader/mod.rs @@ -1 +1,2 @@ -mod downloader; +pub mod downloader; +pub mod huggingface; diff --git a/src/main.rs b/src/main.rs index 31e44ba..ee5a89a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,8 @@ use crate::cli::commands::{run, Cli}; use crate::util::file; fn main() { - env_logger::init(); + // Initialize logger. + env_logger::Builder::from_env(env_logger::Env::default()).init(); // Create the root folder if it doesn't exist. file::create_folder_if_not_exists(&file::root_home()).unwrap(); From f72c0202af9c6872da73b04d152c472393cc415d Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Sun, 19 Apr 2026 19:30:40 +0100 Subject: [PATCH 02/11] Support `puma rm ` (#17) * support new cache structure Signed-off-by: kerthcet * support puma rm Signed-off-by: kerthcet * use readable format Signed-off-by: kerthcet * remove requests.rs Signed-off-by: kerthcet * fix lint Signed-off-by: kerthcet --------- Signed-off-by: kerthcet --- .github/workflows/rust-ci.yaml | 40 +++++ Cargo.lock | 117 ++++++++++++++- Cargo.toml | 5 + Makefile | 4 + README.md | 4 +- src/cli/commands.rs | 91 +++++++---- src/downloader/downloader.rs | 3 +- src/downloader/huggingface.rs | 125 ++++++++-------- src/downloader/mod.rs | 1 + src/lib.rs | 1 + src/main.rs | 2 +- src/registry/mod.rs | 1 + src/registry/model_registry.rs | 266 +++++++++++++++++++++++++++++++++ src/util/file.rs | 22 ++- src/util/format.rs | 212 ++++++++++++++++++++++++++ src/util/mod.rs | 2 +- src/util/request.rs | 153 ------------------- 17 files changed, 794 insertions(+), 255 deletions(-) create mode 100644 .github/workflows/rust-ci.yaml create mode 100644 src/registry/mod.rs create mode 100644 src/registry/model_registry.rs create mode 100644 src/util/format.rs delete mode 100644 src/util/request.rs diff --git a/.github/workflows/rust-ci.yaml b/.github/workflows/rust-ci.yaml new file mode 100644 index 
0000000..0114db4 --- /dev/null +++ b/.github/workflows/rust-ci.yaml @@ -0,0 +1,40 @@ +name: Rust CI + +on: + push: + branches: [ main, feat/* ] + pull_request: + branches: [ main ] + +env: + CARGO_TERM_COLOR: always + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: stable + components: rustfmt, clippy + + - name: Run lint + run: make lint + + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + with: + toolchain: stable + + - name: Run tests + run: make test diff --git a/Cargo.lock b/Cargo.lock index 99f9700..c5d05ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "0.6.18" @@ -154,6 +163,19 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "clap" version = "4.5.30" @@ -794,6 +816,30 @@ dependencies = [ "tracing", ] +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + 
"android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "icu_collections" version = "1.5.0" @@ -1120,6 +1166,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.17.0" @@ -1305,6 +1360,7 @@ dependencies = [ name = "puma" version = "0.0.1" dependencies = [ + "chrono", "clap", "colored", "dirs", @@ -1316,6 +1372,8 @@ dependencies = [ "reqwest", "serde", "serde_derive", + "serde_json", + "tempfile", "tokio", ] @@ -2269,6 +2327,41 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result 0.4.1", + "windows-strings 0.5.1", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" @@ -2281,8 +2374,8 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" dependencies = [ - "windows-result", - "windows-strings", + "windows-result 0.2.0", + "windows-strings 0.1.0", "windows-targets", ] @@ -2295,16 +2388,34 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-strings" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" dependencies = [ - "windows-result", + "windows-result 0.2.0", "windows-targets", ] +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index f79acd5..bad4373 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,8 @@ indicatif = "0.17.11" dirs = "6.0.0" hf-hub = { version = "0.5.0", features = ["tokio"] } colored = "2.1" +chrono = "0.4" +serde_json = "1.0" + +[dev-dependencies] +tempfile = "3.12" diff --git a/Makefile b/Makefile index 4fad7b1..61ed3f8 100644 --- a/Makefile +++ b/Makefile @@ -3,3 +3,7 @@ build: test: cargo test + +lint: + cargo fmt --all -- --check + cargo clippy --all-targets --all-features -- -D warnings diff --git 
a/README.md b/README.md index 3adf683..7b76557 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ ## Quick Start -### Installation +### Install from source ```bash make build @@ -34,7 +34,7 @@ make build ### Build ```bash -make compile +make build ``` ### Test diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 6e86627..c4196b6 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -1,9 +1,10 @@ use clap::{Parser, Subcommand}; use prettytable::{format, row, Table}; -use std::path::PathBuf; use crate::downloader::downloader::Downloader; use crate::downloader::huggingface::HuggingFaceDownloader; +use crate::registry::model_registry::ModelRegistry; +use crate::util::format::{format_size, format_time_ago}; #[derive(Parser)] #[command(name = "PUMA")] @@ -14,6 +15,7 @@ pub struct Cli { } #[derive(Subcommand)] +#[allow(clippy::upper_case_acronyms)] enum Commands { /// List running models PS, @@ -26,7 +28,7 @@ enum Commands { /// Stop one running model STOP, /// Remove one model - RM, + RM(RmArgs), /// Display system-wide information INFO, /// Return detailed information about a model @@ -37,7 +39,7 @@ enum Commands { #[derive(Parser)] struct PullArgs { - #[arg(short = 'm', long, value_name = "model name")] + /// Model name to download (e.g., InftyAI/tiny-random-gpt2) model: String, #[arg( short = 'p', @@ -47,22 +49,21 @@ struct PullArgs { default_value = "huggingface" )] provider: Provider, - #[arg(long, value_name = "cache directory")] - cache_dir: Option, } -#[derive(Debug, Clone, clap::ValueEnum)] +#[derive(Parser)] +struct RmArgs { + /// Model name to remove (e.g., InftyAI/tiny-random-gpt2) + model: String, +} + +#[derive(Debug, Clone, Default, clap::ValueEnum)] pub enum Provider { + #[default] Huggingface, Modelscope, } -impl Default for Provider { - fn default() -> Self { - Provider::Huggingface - } -} - // Support commands like: pull, ls, run, ps, stop, rm, info, inspect, show. 
pub async fn run(cli: Cli) { match cli.command { @@ -82,29 +83,42 @@ pub async fn run(cli: Cli) { } Commands::LS => { + let registry = ModelRegistry::new(None); + let models = registry.load_models().unwrap_or_default(); + let mut table = Table::new(); table.set_format(*format::consts::FORMAT_CLEAN); - table.add_row(row!["MODEl", "PROVIDER", "REVISION", "SIZE", "CREATED"]); - table.add_row(row![ - "deepseek-ai/DeepSeek-R1", - "huggingface", - "main", - "80GB", - "2 weeks ago" - ]); + table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "CREATED"]); + + for model in models { + let size_str = format_size(model.size); + + let revision_short = if model.revision.len() > 8 { + &model.revision[..8] + } else { + &model.revision + }; + + let created_str = format_time_ago(&model.created_at); + + table.add_row(row![ + model.name, + model.provider, + revision_short, + size_str, + created_str + ]); + } + table.printstd(); } Commands::PULL(args) => match args.provider { Provider::Huggingface => { let downloader = HuggingFaceDownloader::new(); - let cache_dir = args.cache_dir.unwrap_or_else(|| PathBuf::new()); - match downloader.download_model(&args.model, &cache_dir).await { - Ok(_) => {} - Err(e) => { - eprintln!("Error downloading model: {}", e); - std::process::exit(1); - } + if let Err(e) = downloader.download_model(&args.model).await { + eprintln!("Error downloading model: {}", e); + std::process::exit(1); } } Provider::Modelscope => { @@ -120,8 +134,27 @@ pub async fn run(cli: Cli) { println!("Stopping one running model..."); } - Commands::RM => { - println!("Removing one model..."); + Commands::RM(args) => { + let registry = ModelRegistry::new(None); + + // Check if model exists first + match registry.get_model(&args.model) { + Ok(Some(_)) => { + // Delete model (cache + registry) + if let Err(e) = registry.remove_model(&args.model) { + eprintln!("Failed to remove model: {}", e); + std::process::exit(1); + } + } + Ok(None) => { + eprintln!("Model not found: {}", 
args.model); + std::process::exit(1); + } + Err(e) => { + eprintln!("Failed to load registry: {}", e); + std::process::exit(1); + } + } } Commands::INFO => { diff --git a/src/downloader/downloader.rs b/src/downloader/downloader.rs index 4b0e59e..21ae0b0 100644 --- a/src/downloader/downloader.rs +++ b/src/downloader/downloader.rs @@ -1,5 +1,4 @@ use core::fmt; -use std::path::PathBuf; #[derive(Debug)] pub enum DownloadError { @@ -25,5 +24,5 @@ impl fmt::Display for DownloadError { } pub trait Downloader { - async fn download_model(&self, name: &str, cache_dir: &PathBuf) -> Result<(), DownloadError>; + async fn download_model(&self, name: &str) -> Result<(), DownloadError>; } diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index d046b6b..baaa7cd 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -1,16 +1,19 @@ use colored::Colorize; use log::{debug, info}; -use std::path::PathBuf; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use hf_hub::api::tokio::{ApiBuilder, Progress}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use crate::downloader::downloader::{DownloadError, Downloader}; +use crate::registry::model_registry::{ModelInfo, ModelRegistry}; +use crate::util::file; #[derive(Clone)] struct FileProgressBar { pb: ProgressBar, + total_size: Arc, } impl Progress for FileProgressBar { @@ -18,6 +21,7 @@ impl Progress for FileProgressBar { self.pb.set_length(size as u64); self.pb.reset(); self.pb.tick(); // Force render with correct size + self.total_size.fetch_add(size as u64, Ordering::Relaxed); } async fn update(&mut self, size: usize) { @@ -42,29 +46,24 @@ impl Default for HuggingFaceDownloader { } impl Downloader for HuggingFaceDownloader { - async fn download_model(&self, name: &str, cache_dir: &PathBuf) -> Result<(), DownloadError> { + async fn download_model(&self, name: &str) -> Result<(), DownloadError> { let start_time = std::time::Instant::now(); info!("Downloading 
model {} from Hugging Face...", name); - // Build API without default progress bars (we have our own implementation) - let api = if cache_dir.as_os_str().is_empty() { - // Use default HF cache - ApiBuilder::new().build().map_err(|e| { + // Use unified PUMA cache directory + let cache_dir = file::huggingface_cache_dir(); + file::create_folder_if_not_exists(&cache_dir).map_err(|e| { + DownloadError::IoError(format!("Failed to create cache directory: {}", e)) + })?; + + // Build API with PUMA cache directory + let api = ApiBuilder::new() + .with_cache_dir(cache_dir.clone()) + .build() + .map_err(|e| { DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e)) - })? - } else { - // Use custom cache directory - ApiBuilder::new() - .with_cache_dir(cache_dir.clone()) - .build() - .map_err(|e| { - DownloadError::ApiError(format!( - "Failed to initialize Hugging Face API with custom cache: {}", - e - )) - })? - }; + })?; // Download the entire model repository using snapshot download let repo = api.model(name.to_string()); @@ -88,19 +87,34 @@ impl Downloader for HuggingFaceDownloader { // Create multi-progress for parallel downloads let multi_progress = Arc::new(MultiProgress::new()); + // Calculate the longest filename for proper alignment + let max_filename_len = model_info + .siblings + .iter() + .map(|s| s.rfilename.len()) + .max() + .unwrap_or(30); + // Progress bar style with block characters (chart-like, not #) + let template = format!( + "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", + width = max_filename_len + ); let style = ProgressStyle::default_bar() - .template("{msg:<30} [{elapsed_precise}] {bar:60.white} {bytes}/{total_bytes}") + .template(&template) .unwrap() .progress_chars("▇▆▅▄▃▂▁ "); // Download all files in parallel let mut tasks = Vec::new(); + let sha = model_info.sha.clone(); + let total_size = Arc::new(AtomicU64::new(0)); for sibling in model_info.siblings { let api_clone = api.clone(); 
let model_name = name.to_string(); let filename = sibling.rfilename.clone(); + let total_size_clone = Arc::clone(&total_size); let pb = multi_progress.add(ProgressBar::hidden()); pb.set_style(style.clone()); @@ -110,7 +124,10 @@ impl Downloader for HuggingFaceDownloader { debug!("Downloading: {}", filename); let repo = api_clone.model(model_name); - let progress = FileProgressBar { pb: pb.clone() }; + let progress = FileProgressBar { + pb: pb.clone(), + total_size: total_size_clone, + }; let result = repo.download_with_progress(&filename, progress).await; @@ -139,6 +156,25 @@ impl Downloader for HuggingFaceDownloader { let elapsed_time = start_time.elapsed(); + // Get accumulated size from downloads + let downloaded_size = total_size.load(Ordering::Relaxed); + let model_cache_path = cache_dir.join(format!("models--{}", name.replace("/", "--"))); + + // Register the model + let model_info_record = ModelInfo { + name: name.to_string(), + provider: "huggingface".to_string(), + revision: sha, + size: downloaded_size, + created_at: chrono::Local::now().to_rfc3339(), + cache_path: model_cache_path.to_string_lossy().to_string(), + }; + + let registry = ModelRegistry::new(None); + registry + .register_model(model_info_record) + .map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?; + println!( "\n{} {} {} {} {:.2?}", "βœ“".green().bold(), @@ -160,7 +196,7 @@ mod tests { async fn test_download_model_invalid() { let downloader = HuggingFaceDownloader::new(); let result = downloader - .download_model("invalid-model-that-does-not-exist-12345", &PathBuf::new()) + .download_model("invalid-model-that-does-not-exist-12345") .await; assert!(result.is_err()); } @@ -169,53 +205,18 @@ mod tests { async fn test_download_real_tiny_model() { let downloader = HuggingFaceDownloader::new(); // Use HF's official tiny test model (only a few KB) - let result = downloader - .download_model("InftyAI/tiny-random-gpt2", &PathBuf::new()) - .await; + let result = 
downloader.download_model("InftyAI/tiny-random-gpt2").await; assert!( result.is_ok(), "Failed to download tiny model: {:?}", result ); - // Cleanup: remove the downloaded files from the default HF cache (~/.cache/huggingface/hub) - if let Some(home_dir) = dirs::home_dir() { - let cache_dir = home_dir - .join(".cache") - .join("huggingface") - .join("hub") - .join("models--InftyAI--tiny-random-gpt2"); + // Cleanup: remove the downloaded files from PUMA cache + let cache_dir = file::huggingface_cache_dir().join("models--InftyAI--tiny-random-gpt2"); - if cache_dir.exists() { - let _ = std::fs::remove_dir_all(&cache_dir); - } + if cache_dir.exists() { + let _ = std::fs::remove_dir_all(&cache_dir); } } - - #[tokio::test] - async fn test_download_with_custom_cache() { - use std::env; - use std::fs; - - let downloader = HuggingFaceDownloader::new(); - let temp_dir = env::temp_dir().join("puma_test_cache"); - - print!("Using temporary cache directory: {:?}\n", temp_dir); - - // Create the directory first - fs::create_dir_all(&temp_dir).unwrap(); - - let result = downloader - .download_model("InftyAI/tiny-random-gpt2", &temp_dir) - .await; - - assert!( - result.is_ok(), - "Failed to download with custom cache: {:?}", - result - ); - - // Cleanup - let _ = std::fs::remove_dir_all(&temp_dir); - } } diff --git a/src/downloader/mod.rs b/src/downloader/mod.rs index 7af4570..2bdb3f5 100644 --- a/src/downloader/mod.rs +++ b/src/downloader/mod.rs @@ -1,2 +1,3 @@ +#[allow(clippy::module_inception)] pub mod downloader; pub mod huggingface; diff --git a/src/lib.rs b/src/lib.rs index e69de29..8b13789 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1 @@ + diff --git a/src/main.rs b/src/main.rs index ee5a89a..e280e53 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,9 @@ mod cli; mod downloader; +mod registry; mod util; use clap::Parser; -use env_logger; use tokio::runtime::Builder; use crate::cli::commands::{run, Cli}; diff --git a/src/registry/mod.rs b/src/registry/mod.rs new 
file mode 100644 index 0000000..8565989 --- /dev/null +++ b/src/registry/mod.rs @@ -0,0 +1 @@ +pub mod model_registry; diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs new file mode 100644 index 0000000..ccc1eab --- /dev/null +++ b/src/registry/model_registry.rs @@ -0,0 +1,266 @@ +use colored::Colorize; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::PathBuf; + +use crate::util::file; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ModelInfo { + pub name: String, + pub provider: String, + pub revision: String, + pub size: u64, + pub created_at: String, + pub cache_path: String, +} + +pub struct ModelRegistry { + home_dir: PathBuf, +} + +impl ModelRegistry { + pub fn new(home_dir: Option) -> Self { + Self { + home_dir: home_dir.unwrap_or_else(file::root_home), + } + } + + fn registry_file(&self) -> PathBuf { + self.home_dir.join("models.json") + } + + pub fn load_models(&self) -> Result, std::io::Error> { + let registry_file = self.registry_file(); + + if !registry_file.exists() { + return Ok(Vec::new()); + } + + let contents = fs::read_to_string(registry_file)?; + let models: Vec = serde_json::from_str(&contents) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + Ok(models) + } + + pub fn save_models(&self, models: &[ModelInfo]) -> Result<(), std::io::Error> { + // Ensure home directory exists + fs::create_dir_all(&self.home_dir)?; + + let registry_file = self.registry_file(); + let json = serde_json::to_string_pretty(models) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + fs::write(registry_file, json)?; + Ok(()) + } + + pub fn register_model(&self, model: ModelInfo) -> Result<(), std::io::Error> { + let mut models = self.load_models()?; + + // Remove existing model with same name if exists + models.retain(|m| m.name != model.name); + + models.push(model); + self.save_models(&models)?; + + Ok(()) + } + + pub fn unregister_model(&self, name: &str) 
-> Result<(), std::io::Error> { + let mut models = self.load_models()?; + models.retain(|m| m.name != name); + self.save_models(&models)?; + + Ok(()) + } + + pub fn get_model(&self, name: &str) -> Result, std::io::Error> { + let models = self.load_models()?; + Ok(models.into_iter().find(|m| m.name == name)) + } + + pub fn remove_model(&self, name: &str) -> Result<(), std::io::Error> { + // Get model info first + let model_info = self.get_model(name)?; + + if let Some(info) = model_info { + // Delete cache directory if it exists + let cache_path = std::path::Path::new(&info.cache_path); + if cache_path.exists() { + fs::remove_dir_all(cache_path)?; + } + + // Remove from registry + self.unregister_model(name)?; + + println!( + "\n{} {} {}", + "βœ“".green().bold(), + "Successfully removed model".bright_white(), + name.cyan().bold() + ); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn test_add_and_load_model() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + }; + + registry.register_model(model.clone()).unwrap(); + + let models = registry.load_models().unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "test/model"); + } + + #[test] + fn test_unregister_model() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + }; + + registry.register_model(model).unwrap(); + 
assert_eq!(registry.load_models().unwrap().len(), 1); + + registry.unregister_model("test/model").unwrap(); + assert_eq!(registry.load_models().unwrap().len(), 0); + } + + #[test] + fn test_get_model() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + }; + + registry.register_model(model).unwrap(); + + let result = registry.get_model("test/model").unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "test/model"); + + let not_found = registry.get_model("nonexistent").unwrap(); + assert!(not_found.is_none()); + } + + #[test] + fn test_remove_nonexistent_model() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + // Should not error when removing non-existent model + let result = registry.unregister_model("nonexistent"); + assert!(result.is_ok()); + } + + #[test] + fn test_update_existing_model() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model1 = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + }; + + registry.register_model(model1).unwrap(); + + let model2 = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "def456".to_string(), + size: 2000, + created_at: "2025-01-02T00:00:00Z".to_string(), + cache_path: "/tmp/test2".to_string(), + }; + + registry.register_model(model2).unwrap(); + + let models = registry.load_models().unwrap(); + assert_eq!(models.len(), 1); + 
assert_eq!(models[0].revision, "def456"); + assert_eq!(models[0].size, 2000); + } + + #[test] + fn test_remove_model_with_cache() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + // Create a fake cache directory + let cache_dir = temp_dir.path().join("cache"); + fs::create_dir_all(&cache_dir).unwrap(); + fs::write(cache_dir.join("test.txt"), "test data").unwrap(); + + let model = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: cache_dir.to_string_lossy().to_string(), + }; + + registry.register_model(model).unwrap(); + assert_eq!(registry.load_models().unwrap().len(), 1); + assert!(cache_dir.exists()); + + // Delete model + registry.remove_model("test/model").unwrap(); + + // Verify model removed from registry + assert_eq!(registry.load_models().unwrap().len(), 0); + + // Verify cache directory deleted + assert!(!cache_dir.exists()); + } + + #[test] + fn test_delete_nonexistent_model() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + // Should not error when deleting non-existent model + let result = registry.remove_model("nonexistent"); + assert!(result.is_ok()); + } +} diff --git a/src/util/file.rs b/src/util/file.rs index 26effb0..602a1ef 100644 --- a/src/util/file.rs +++ b/src/util/file.rs @@ -9,6 +9,24 @@ pub fn create_folder_if_not_exists(folder_path: &PathBuf) -> std::io::Result<()> } pub fn root_home() -> PathBuf { - let home = home_dir().expect("Failed to get home directory"); - home.join(".puma") + // Allow tests to override PUMA home directory + if let Ok(test_home) = std::env::var("PUMA_HOME") { + PathBuf::from(test_home) + } else { + let home = home_dir().expect("Failed to get home directory"); + home.join(".puma") + } +} + +pub fn cache_dir() -> PathBuf { + 
root_home().join("cache") +} + +pub fn huggingface_cache_dir() -> PathBuf { + cache_dir().join("huggingface") +} + +#[allow(dead_code)] +pub fn modelscope_cache_dir() -> PathBuf { + cache_dir().join("modelscope") } diff --git a/src/util/format.rs b/src/util/format.rs new file mode 100644 index 0000000..72ce86d --- /dev/null +++ b/src/util/format.rs @@ -0,0 +1,212 @@ +use chrono::{DateTime, Utc}; + +/// Format byte size to human-readable format (B, KB, MB, GB) +pub fn format_size(bytes: u64) -> String { + if bytes > 1_000_000_000 { + format!("{:.2} GB", bytes as f64 / 1_000_000_000.0) + } else if bytes > 1_000_000 { + format!("{:.2} MB", bytes as f64 / 1_000_000.0) + } else if bytes > 1_000 { + format!("{:.2} KB", bytes as f64 / 1_000.0) + } else { + format!("{} B", bytes) + } +} + +/// Format RFC3339 timestamp to human-readable relative time (e.g., "2 hours ago") +pub fn format_time_ago(timestamp: &str) -> String { + // Try to parse as RFC3339 + let created_time = match DateTime::parse_from_rfc3339(timestamp) { + Ok(dt) => dt.with_timezone(&Utc), + Err(_) => return timestamp.to_string(), // Return original if parse fails + }; + + let now = Utc::now(); + let duration = now.signed_duration_since(created_time); + + let seconds = duration.num_seconds(); + + if seconds < 0 { + "just now".to_string() + } else if seconds < 60 { + format!("{} seconds ago", seconds) + } else if seconds < 3600 { + let minutes = seconds / 60; + format!( + "{} {} ago", + minutes, + if minutes == 1 { "minute" } else { "minutes" } + ) + } else if seconds < 86400 { + let hours = seconds / 3600; + format!( + "{} {} ago", + hours, + if hours == 1 { "hour" } else { "hours" } + ) + } else if seconds < 2592000 { + let days = seconds / 86400; + format!("{} {} ago", days, if days == 1 { "day" } else { "days" }) + } else if seconds < 31536000 { + let months = seconds / 2592000; + format!( + "{} {} ago", + months, + if months == 1 { "month" } else { "months" } + ) + } else { + let years = seconds / 
31536000; + format!( + "{} {} ago", + years, + if years == 1 { "year" } else { "years" } + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_size_bytes() { + assert_eq!(format_size(0), "0 B"); + assert_eq!(format_size(1), "1 B"); + assert_eq!(format_size(999), "999 B"); + assert_eq!(format_size(1000), "1000 B"); + } + + #[test] + fn test_format_size_kilobytes() { + assert_eq!(format_size(1_001), "1.00 KB"); + assert_eq!(format_size(1_500), "1.50 KB"); + assert_eq!(format_size(10_000), "10.00 KB"); + assert_eq!(format_size(999_999), "1000.00 KB"); + } + + #[test] + fn test_format_size_megabytes() { + assert_eq!(format_size(1_000_001), "1.00 MB"); + assert_eq!(format_size(1_500_000), "1.50 MB"); + assert_eq!(format_size(10_000_000), "10.00 MB"); + assert_eq!(format_size(500_000_000), "500.00 MB"); + } + + #[test] + fn test_format_size_gigabytes() { + assert_eq!(format_size(1_000_000_001), "1.00 GB"); + assert_eq!(format_size(1_500_000_000), "1.50 GB"); + assert_eq!(format_size(10_000_000_000), "10.00 GB"); + assert_eq!(format_size(100_000_000_000), "100.00 GB"); + } + + #[test] + fn test_format_size_edge_cases() { + // Boundary between KB and MB + assert_eq!(format_size(1_000_000), "1000.00 KB"); + assert_eq!(format_size(1_000_001), "1.00 MB"); + + // Boundary between MB and GB + assert_eq!(format_size(1_000_000_000), "1000.00 MB"); + assert_eq!(format_size(1_000_000_001), "1.00 GB"); + } + + #[test] + fn test_format_size_realistic_model_sizes() { + // Small model (100 MB) + assert_eq!(format_size(104_857_600), "104.86 MB"); + + // Medium model (7 GB) + assert_eq!(format_size(7_516_192_768), "7.52 GB"); + + // Large model (65 GB) + assert_eq!(format_size(69_793_218_560), "69.79 GB"); + } + + #[test] + fn test_format_time_ago_seconds() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now - Duration::seconds(30)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "30 seconds ago"); + + let timestamp = (now - 
Duration::seconds(1)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "1 seconds ago"); + } + + #[test] + fn test_format_time_ago_minutes() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now - Duration::minutes(5)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "5 minutes ago"); + + let timestamp = (now - Duration::minutes(1)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "1 minute ago"); + } + + #[test] + fn test_format_time_ago_hours() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now - Duration::hours(3)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "3 hours ago"); + + let timestamp = (now - Duration::hours(1)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "1 hour ago"); + } + + #[test] + fn test_format_time_ago_days() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now - Duration::days(7)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "7 days ago"); + + let timestamp = (now - Duration::days(1)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "1 day ago"); + } + + #[test] + fn test_format_time_ago_months() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now - Duration::days(60)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "2 months ago"); + + let timestamp = (now - Duration::days(30)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "1 month ago"); + } + + #[test] + fn test_format_time_ago_years() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now - Duration::days(730)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "2 years ago"); + + let timestamp = (now - Duration::days(365)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "1 year ago"); + } + + #[test] + fn test_format_time_ago_future() { + use chrono::Duration; + + let now = Utc::now(); + let timestamp = (now + Duration::hours(5)).to_rfc3339(); + assert_eq!(format_time_ago(&timestamp), "just now"); + } + + #[test] + fn test_format_time_ago_invalid() 
{ + let invalid = "not-a-timestamp"; + assert_eq!(format_time_ago(invalid), "not-a-timestamp"); + } +} diff --git a/src/util/mod.rs b/src/util/mod.rs index c756da3..73f5099 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,2 +1,2 @@ pub mod file; -pub mod request; +pub mod format; diff --git a/src/util/request.rs b/src/util/request.rs deleted file mode 100644 index a82369d..0000000 --- a/src/util/request.rs +++ /dev/null @@ -1,153 +0,0 @@ -use core::time; -use std::error::Error; -use std::fs::File; -use std::io; -use std::os::unix::fs::FileExt; -use std::path::PathBuf; -use std::sync::Arc; - -use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; -use log::{debug, error}; -use reqwest::Client; -use tokio::sync::Semaphore; - -const MAX_CHUNK_CONCURRENCY: usize = 20; -const CHUNK_SIZE: usize = 1000 * 1000 * 50; // 50MB -const MAX_RETRIES: usize = 5; -const SLEEP_FACTOR: usize = 500; // 500 ms - -pub async fn download_file( - client: Arc, - url: String, - content_length: u64, - filename: String, - output_path: &PathBuf, - m: Arc, - sty: ProgressStyle, -) -> Result<(), Box> { - debug!( - "Start to download file {} to {}", - filename, - output_path.display() - ); - - let mut tasks = Vec::new(); - let mut start = 0; - let mut end = CHUNK_SIZE as u64 - 1; - end = end.min(content_length - 1); - - let semaphore = Arc::new(Semaphore::new(MAX_CHUNK_CONCURRENCY)); - // TODO: verify the file not downloaded yet. 
- let file = Arc::new(File::create(&output_path)?); - let arc_url = Arc::new(url); - - let pb = m.add(ProgressBar::new(content_length).with_style(sty)); - pb.set_message(filename.clone()); - let arc_pb = Arc::new(pb); - - while start < content_length { - let client = Arc::clone(&client); - let semaphore = Arc::clone(&semaphore); - let file = Arc::clone(&file); - let url = Arc::clone(&arc_url); - let pb = Arc::clone(&arc_pb); - - let fname = filename.clone(); - - let task = tokio::spawn(async move { - let _permit = semaphore.acquire().await.unwrap(); - let _ = download_chunk_with_retries( - client, - file, - fname, - url, - start.clone(), - end.clone(), - MAX_RETRIES, - ) - .await; - - pb.inc(end - start + 1); - }); - tasks.push(task); - - start = end + 1; - end = (end + CHUNK_SIZE as u64).min(content_length - 1); - } - - for task in tasks { - let _ = task.await; - // TODO: write to a file about the chunk info. - } - - arc_pb.finish(); - Ok(()) -} - -async fn download_chunk_with_retries( - client: Arc, - file: Arc, - filename: String, - url: Arc, - start: u64, - end: u64, - retries: usize, -) -> Result<(), Box> { - debug!("Start to download chunk {}:{}-{}", filename, start, end,); - - let mut retries = retries; - loop { - match download_chunk(&client, &file, &url, start, end).await { - Ok(_) => { - debug!("Download chunk {}:{}-{} successfully", filename, start, end); - break; - } - // TODO: retry only when http error. - Err(e) => { - if retries == 0 { - error!("Reach the maximum retries {}. 
Return", MAX_RETRIES); - return Err(e); - } - retries -= 1; - - let _ = tokio::time::sleep(time::Duration::from_millis( - SLEEP_FACTOR as u64 * 2u64.pow((MAX_RETRIES - retries) as u32), - )); - - error!( - "Failed to download chunk {}:{}-{}, err: {}, retrying {}...", - filename, - start, - end, - e.to_string(), - MAX_RETRIES - retries - ); - } - } - } - - Ok(()) -} - -async fn download_chunk( - client: &Client, - file: &File, - url: &str, - start: u64, - end: u64, -) -> Result<(), Box> { - let response = client - .get(url) - .header("Range", format!("bytes={}-{}", start, end)) - .send() - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - let chunk = response - .bytes() - .await - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - file.write_all_at(&chunk, start)?; - Ok(()) -} From 6135ac3da1b6288bc9c9d04b5725e77d2440f6fc Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Sun, 19 Apr 2026 21:08:52 +0100 Subject: [PATCH 03/11] support puma info (#18) Signed-off-by: kerthcet --- Cargo.lock | 134 +++++++++++++++++++++++++++++++++++++- Cargo.toml | 1 + src/cli/commands.rs | 4 +- src/main.rs | 1 + src/system/mod.rs | 1 + src/system/system_info.rs | 107 ++++++++++++++++++++++++++++++ src/util/format.rs | 68 ++++++++++--------- 7 files changed, 280 insertions(+), 36 deletions(-) create mode 100644 src/system/mod.rs create mode 100644 src/system/system_info.rs diff --git a/Cargo.lock b/Cargo.lock index c5d05ed..6554c89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -311,6 +311,31 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "csv" version = "1.3.1" @@ -413,6 +438,12 @@ dependencies = [ "litrs", ] +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "encode_unicode" version = "1.0.0" @@ -828,7 +859,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core", + "windows-core 0.62.2", ] [[package]] @@ -1160,6 +1191,15 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ntapi" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3b335231dfd352ffb0f8017f3b6027a4917f7df785ea2143d8af2adc66980ae" +dependencies = [ + "winapi", +] + [[package]] name = "num-conv" version = "0.2.1" @@ -1373,6 +1413,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", + "sysinfo", "tempfile", "tokio", ] @@ -1415,6 +1456,26 @@ dependencies = [ "getrandom 0.3.1", ] +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.9" @@ -1805,6 +1866,20 @@ dependencies = [ "syn", ] +[[package]] +name = "sysinfo" +version = "0.32.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" +dependencies = [ + "core-foundation-sys", + "libc", + "memchr", + "ntapi", + "rayon", + "windows", +] + [[package]] name = "system-configuration" version = "0.6.1" @@ -2327,19 +2402,52 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +dependencies = [ + "windows-core 0.57.0", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" +dependencies = [ + "windows-implement 0.57.0", + "windows-interface 0.57.0", + "windows-result 0.1.2", + "windows-targets", +] + [[package]] name = "windows-core" version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ - "windows-implement", - "windows-interface", + "windows-implement 0.60.2", + "windows-interface 0.59.3", "windows-link", "windows-result 0.4.1", "windows-strings 0.5.1", ] +[[package]] +name = "windows-implement" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-implement" version = "0.60.2" @@ -2351,6 +2459,17 @@ dependencies = [ "syn", ] +[[package]] +name = "windows-interface" +version = "0.57.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-interface" version = "0.59.3" @@ -2379,6 +2498,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-result" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e383302e8ec8515204254685643de10811af0ed97ea37210dc26fb0032647f8" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-result" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index bad4373..eab4ff7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ hf-hub = { version = "0.5.0", features = ["tokio"] } colored = "2.1" chrono = "0.4" serde_json = "1.0" +sysinfo = "0.32" [dev-dependencies] tempfile = "3.12" diff --git a/src/cli/commands.rs b/src/cli/commands.rs index c4196b6..8a04b59 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -4,6 +4,7 @@ use prettytable::{format, row, Table}; use crate::downloader::downloader::Downloader; use crate::downloader::huggingface::HuggingFaceDownloader; use crate::registry::model_registry::ModelRegistry; +use crate::system::system_info::SystemInfo; use crate::util::format::{format_size, format_time_ago}; #[derive(Parser)] @@ -158,7 +159,8 @@ pub async fn run(cli: Cli) { } Commands::INFO => { - println!("Displaying system-wide information..."); + let info = SystemInfo::collect(); + info.display(); } Commands::INSPECT => { diff --git a/src/main.rs b/src/main.rs index e280e53..02d7f23 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod cli; mod downloader; mod registry; +mod system; mod util; use clap::Parser; diff --git a/src/system/mod.rs b/src/system/mod.rs new file mode 100644 index 0000000..11b5b6f --- /dev/null +++ b/src/system/mod.rs @@ -0,0 +1 @@ +pub mod system_info; diff --git a/src/system/system_info.rs b/src/system/system_info.rs new file mode 100644 index 0000000..8abb7df --- /dev/null +++ 
b/src/system/system_info.rs @@ -0,0 +1,107 @@ +use serde::{Deserialize, Serialize}; +use std::fs; +use std::os::unix::fs::MetadataExt; +use std::path::PathBuf; +use sysinfo::System; + +use crate::registry::model_registry::ModelRegistry; +use crate::util::file; +use crate::util::format::format_size; + +#[derive(Debug, Serialize, Deserialize)] +pub struct SystemInfo { + pub version: String, + pub os: String, + pub architecture: String, + pub cpu_cores: usize, + pub total_memory: String, + pub available_memory: String, + pub cache_dir: String, + pub cache_size: String, + pub models_count: usize, + pub running_models: usize, +} + +impl SystemInfo { + pub fn collect() -> Self { + let mut sys = System::new_all(); + sys.refresh_all(); + + let cache_dir = file::cache_dir(); + let cache_size = Self::calculate_cache_size(&cache_dir); + + let registry = ModelRegistry::new(None); + let models_count = registry.load_models().unwrap_or_default().len(); + + SystemInfo { + version: env!("CARGO_PKG_VERSION").to_string(), + os: System::name().unwrap_or_else(|| "Unknown".to_string()), + architecture: System::cpu_arch().unwrap_or_else(|| "Unknown".to_string()), + cpu_cores: sys.cpus().len(), + total_memory: format_size(sys.total_memory()), + available_memory: format_size(sys.available_memory()), + cache_dir: cache_dir.to_string_lossy().to_string(), + cache_size: format_size(cache_size), + models_count, + running_models: 0, // TODO: implement running models tracking + } + } + + fn calculate_cache_size(cache_dir: &PathBuf) -> u64 { + if !cache_dir.exists() { + return 0; + } + + let mut total_size = 0u64; + + if let Ok(entries) = fs::read_dir(cache_dir) { + for entry in entries.flatten() { + if let Ok(metadata) = entry.metadata() { + if metadata.is_file() { + // Use blocks * 512 to get actual disk usage (handles sparse files) + total_size += metadata.blocks() * 512; + } else if metadata.is_dir() { + total_size += Self::dir_size(&entry.path()); + } + } + } + } + + total_size + } + + fn 
dir_size(path: &PathBuf) -> u64 { + let mut total_size = 0u64; + + if let Ok(entries) = fs::read_dir(path) { + for entry in entries.flatten() { + if let Ok(metadata) = entry.metadata() { + if metadata.is_file() { + // Use blocks * 512 to get actual disk usage (handles sparse files) + total_size += metadata.blocks() * 512; + } else if metadata.is_dir() { + total_size += Self::dir_size(&entry.path()); + } + } + } + } + + total_size + } + + pub fn display(&self) { + println!("System Information:"); + println!(" Operating System: {}", self.os); + println!(" Architecture: {}", self.architecture); + println!(" CPU Cores: {}", self.cpu_cores); + println!(" Total Memory: {}", self.total_memory); + println!(" Available Memory: {}", self.available_memory); + println!(); + println!("PUMA Information:"); + println!(" PUMA Version: {}", self.version); + println!(" Cache Directory: {}", self.cache_dir); + println!(" Cache Size: {}", self.cache_size); + println!(" Models: {}", self.models_count); + println!(" Running Models: {}", self.running_models); + } +} diff --git a/src/util/format.rs b/src/util/format.rs index 72ce86d..70e109b 100644 --- a/src/util/format.rs +++ b/src/util/format.rs @@ -1,13 +1,17 @@ use chrono::{DateTime, Utc}; -/// Format byte size to human-readable format (B, KB, MB, GB) +/// Format byte size to human-readable format (B, KiB, MiB, GiB) pub fn format_size(bytes: u64) -> String { - if bytes > 1_000_000_000 { - format!("{:.2} GB", bytes as f64 / 1_000_000_000.0) - } else if bytes > 1_000_000 { - format!("{:.2} MB", bytes as f64 / 1_000_000.0) - } else if bytes > 1_000 { - format!("{:.2} KB", bytes as f64 / 1_000.0) + const KIB: f64 = 1024.0; + const MIB: f64 = 1024.0 * 1024.0; + const GIB: f64 = 1024.0 * 1024.0 * 1024.0; + + if bytes as f64 >= GIB { + format!("{:.2} GiB", bytes as f64 / GIB) + } else if bytes as f64 >= MIB { + format!("{:.2} MiB", bytes as f64 / MIB) + } else if bytes as f64 >= KIB { + format!("{:.2} KiB", bytes as f64 / KIB) } else { 
format!("{} B", bytes) } @@ -73,54 +77,54 @@ mod tests { assert_eq!(format_size(0), "0 B"); assert_eq!(format_size(1), "1 B"); assert_eq!(format_size(999), "999 B"); - assert_eq!(format_size(1000), "1000 B"); + assert_eq!(format_size(1023), "1023 B"); } #[test] fn test_format_size_kilobytes() { - assert_eq!(format_size(1_001), "1.00 KB"); - assert_eq!(format_size(1_500), "1.50 KB"); - assert_eq!(format_size(10_000), "10.00 KB"); - assert_eq!(format_size(999_999), "1000.00 KB"); + assert_eq!(format_size(1024), "1.00 KiB"); + assert_eq!(format_size(1536), "1.50 KiB"); + assert_eq!(format_size(10240), "10.00 KiB"); + assert_eq!(format_size(1_048_575), "1024.00 KiB"); } #[test] fn test_format_size_megabytes() { - assert_eq!(format_size(1_000_001), "1.00 MB"); - assert_eq!(format_size(1_500_000), "1.50 MB"); - assert_eq!(format_size(10_000_000), "10.00 MB"); - assert_eq!(format_size(500_000_000), "500.00 MB"); + assert_eq!(format_size(1_048_576), "1.00 MiB"); + assert_eq!(format_size(1_572_864), "1.50 MiB"); + assert_eq!(format_size(10_485_760), "10.00 MiB"); + assert_eq!(format_size(524_288_000), "500.00 MiB"); } #[test] fn test_format_size_gigabytes() { - assert_eq!(format_size(1_000_000_001), "1.00 GB"); - assert_eq!(format_size(1_500_000_000), "1.50 GB"); - assert_eq!(format_size(10_000_000_000), "10.00 GB"); - assert_eq!(format_size(100_000_000_000), "100.00 GB"); + assert_eq!(format_size(1_073_741_824), "1.00 GiB"); + assert_eq!(format_size(1_610_612_736), "1.50 GiB"); + assert_eq!(format_size(10_737_418_240), "10.00 GiB"); + assert_eq!(format_size(107_374_182_400), "100.00 GiB"); } #[test] fn test_format_size_edge_cases() { - // Boundary between KB and MB - assert_eq!(format_size(1_000_000), "1000.00 KB"); - assert_eq!(format_size(1_000_001), "1.00 MB"); + // Boundary between KiB and MiB + assert_eq!(format_size(1_048_575), "1024.00 KiB"); + assert_eq!(format_size(1_048_576), "1.00 MiB"); - // Boundary between MB and GB - assert_eq!(format_size(1_000_000_000), 
"1000.00 MB"); - assert_eq!(format_size(1_000_000_001), "1.00 GB"); + // Boundary between MiB and GiB + assert_eq!(format_size(1_073_741_823), "1024.00 MiB"); + assert_eq!(format_size(1_073_741_824), "1.00 GiB"); } #[test] fn test_format_size_realistic_model_sizes() { - // Small model (100 MB) - assert_eq!(format_size(104_857_600), "104.86 MB"); + // Small model (100 MiB) + assert_eq!(format_size(104_857_600), "100.00 MiB"); - // Medium model (7 GB) - assert_eq!(format_size(7_516_192_768), "7.52 GB"); + // Medium model (7 GiB) + assert_eq!(format_size(7_516_192_768), "7.00 GiB"); - // Large model (65 GB) - assert_eq!(format_size(69_793_218_560), "69.79 GB"); + // Large model (65 GiB) + assert_eq!(format_size(69_793_218_560), "65.00 GiB"); } #[test] From 4947878f0ca8e07eaf139f6caddc0f39c3fae746 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Tue, 21 Apr 2026 00:57:09 +0100 Subject: [PATCH 04/11] Reuse the model cache to avoid duplicate download (#19) * polish the format of the ls command Signed-off-by: kerthcet * Have a progress manager Signed-off-by: kerthcet * Reuse caches Signed-off-by: kerthcet * rename util to utils Signed-off-by: kerthcet * polish the layout of the download progress Signed-off-by: kerthcet * revert change Signed-off-by: kerthcet * add make format Signed-off-by: kerthcet --------- Signed-off-by: kerthcet --- Cargo.lock | 38 +---------- Cargo.toml | 2 +- Makefile | 4 ++ src/cli/commands.rs | 6 +- src/downloader/huggingface.rs | 111 +++++++++++++++++---------------- src/downloader/mod.rs | 1 + src/downloader/progress.rs | 109 ++++++++++++++++++++++++++++++++ src/main.rs | 4 +- src/registry/model_registry.rs | 4 +- src/system/system_info.rs | 4 +- src/util/file.rs | 32 ---------- src/utils/file.rs | 95 ++++++++++++++++++++++++++++ src/{util => utils}/format.rs | 60 ++++++++++++++++++ src/{util => utils}/mod.rs | 0 14 files changed, 338 insertions(+), 132 deletions(-) create mode 100644 src/downloader/progress.rs delete mode 100644 src/util/file.rs 
create mode 100644 src/utils/file.rs rename src/{util => utils}/format.rs (76%) rename src/{util => utils}/mod.rs (100%) diff --git a/Cargo.lock b/Cargo.lock index 6554c89..7974fe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,19 +232,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "console" -version = "0.15.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" -dependencies = [ - "encode_unicode", - "libc", - "once_cell", - "unicode-width 0.2.0", - "windows-sys 0.59.0", -] - [[package]] name = "console" version = "0.16.3" @@ -714,7 +701,7 @@ dependencies = [ "dirs", "futures", "http", - "indicatif 0.18.4", + "indicatif", "libc", "log", "native-tls", @@ -1020,26 +1007,13 @@ dependencies = [ "hashbrown", ] -[[package]] -name = "indicatif" -version = "0.17.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" -dependencies = [ - "console 0.15.10", - "number_prefix", - "portable-atomic", - "unicode-width 0.2.0", - "web-time", -] - [[package]] name = "indicatif" version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console 0.16.3", + "console", "portable-atomic", "unicode-width 0.2.0", "unit-prefix", @@ -1225,12 +1199,6 @@ dependencies = [ "libc", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - [[package]] name = "object" version = "0.36.7" @@ -1406,7 +1374,7 @@ dependencies = [ "dirs", "env_logger", "hf-hub", - "indicatif 0.17.11", + "indicatif", "log", "prettytable-rs", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index eab4ff7..e05dce3 100644 --- a/Cargo.toml +++ b/Cargo.toml 
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] } serde_derive = "1.0" env_logger = "0.11.6" log = "0.4.26" -indicatif = "0.17.11" +indicatif = "0.18" dirs = "6.0.0" hf-hub = { version = "0.5.0", features = ["tokio"] } colored = "2.1" diff --git a/Makefile b/Makefile index 61ed3f8..0e72459 100644 --- a/Makefile +++ b/Makefile @@ -7,3 +7,7 @@ test: lint: cargo fmt --all -- --check cargo clippy --all-targets --all-features -- -D warnings + +format: + cargo fmt --all + cargo clippy --fix --allow-dirty \ No newline at end of file diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 8a04b59..128d137 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -5,7 +5,7 @@ use crate::downloader::downloader::Downloader; use crate::downloader::huggingface::HuggingFaceDownloader; use crate::registry::model_registry::ModelRegistry; use crate::system::system_info::SystemInfo; -use crate::util::format::{format_size, format_time_ago}; +use crate::utils::format::{format_size_decimal, format_time_ago}; #[derive(Parser)] #[command(name = "PUMA")] @@ -92,7 +92,7 @@ pub async fn run(cli: Cli) { table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "CREATED"]); for model in models { - let size_str = format_size(model.size); + let size_str = format_size_decimal(model.size); let revision_short = if model.revision.len() > 8 { &model.revision[..8] @@ -118,7 +118,7 @@ pub async fn run(cli: Cli) { Provider::Huggingface => { let downloader = HuggingFaceDownloader::new(); if let Err(e) = downloader.download_model(&args.model).await { - eprintln!("Error downloading model: {}", e); + eprintln!("❌ Error downloading model: {}", e); std::process::exit(1); } } diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index baaa7cd..0e16a3b 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -1,34 +1,31 @@ use colored::Colorize; -use log::{debug, info}; -use std::sync::atomic::{AtomicU64, Ordering}; -use 
std::sync::Arc; +use log::debug; use hf_hub::api::tokio::{ApiBuilder, Progress}; -use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use crate::downloader::downloader::{DownloadError, Downloader}; +use crate::downloader::progress::{DownloadProgressManager, FileProgress}; use crate::registry::model_registry::{ModelInfo, ModelRegistry}; -use crate::util::file; +use crate::utils::file::{self, format_model_name}; +/// Adapter to bridge HuggingFace's Progress trait with our FileProgress #[derive(Clone)] -struct FileProgressBar { - pb: ProgressBar, - total_size: Arc, +struct HfProgressAdapter { + progress: FileProgress, } -impl Progress for FileProgressBar { +impl Progress for HfProgressAdapter { async fn init(&mut self, size: usize, _filename: &str) { - self.pb.set_length(size as u64); - self.pb.reset(); - self.pb.tick(); // Force render with correct size - self.total_size.fetch_add(size as u64, Ordering::Relaxed); + self.progress.init(size as u64); } async fn update(&mut self, size: usize) { - self.pb.inc(size as u64); + self.progress.update(size as u64); } - async fn finish(&mut self) {} + async fn finish(&mut self) { + self.progress.finish(); + } } pub struct HuggingFaceDownloader; @@ -49,7 +46,7 @@ impl Downloader for HuggingFaceDownloader { async fn download_model(&self, name: &str) -> Result<(), DownloadError> { let start_time = std::time::Instant::now(); - info!("Downloading model {} from Hugging Face...", name); + debug!("Downloading model {} from Hugging Face...", name); // Use unified PUMA cache directory let cache_dir = file::huggingface_cache_dir(); @@ -65,6 +62,8 @@ impl Downloader for HuggingFaceDownloader { DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e)) })?; + println!("πŸ† pulling manifest"); + // Download the entire model repository using snapshot download let repo = api.model(name.to_string()); @@ -84,9 +83,6 @@ impl Downloader for HuggingFaceDownloader { debug!("Model info for {}: {:?}", name, model_info); - 
// Create multi-progress for parallel downloads - let multi_progress = Arc::new(MultiProgress::new()); - // Calculate the longest filename for proper alignment let max_filename_len = model_info .siblings @@ -95,54 +91,59 @@ impl Downloader for HuggingFaceDownloader { .max() .unwrap_or(30); - // Progress bar style with block characters (chart-like, not #) - let template = format!( - "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", - width = max_filename_len - ); - let style = ProgressStyle::default_bar() - .template(&template) - .unwrap() - .progress_chars("▇▆▅▄▃▂▁ "); + // Create progress manager + let progress_manager = DownloadProgressManager::new(max_filename_len); - // Download all files in parallel - let mut tasks = Vec::new(); + // Calculate cache paths + let model_cache_path = cache_dir.join(format_model_name(name)); let sha = model_info.sha.clone(); - let total_size = Arc::new(AtomicU64::new(0)); + let snapshot_path = model_cache_path.join("snapshots").join(&sha); + + // Process all files in manifest order (cached files show as instantly complete) + let mut tasks = Vec::new(); for sibling in model_info.siblings { let api_clone = api.clone(); let model_name = name.to_string(); let filename = sibling.rfilename.clone(); - let total_size_clone = Arc::clone(&total_size); - - let pb = multi_progress.add(ProgressBar::hidden()); - pb.set_style(style.clone()); - pb.set_message(filename.clone()); + let progress_manager_clone = progress_manager.clone(); + let snapshot_path_clone = snapshot_path.clone(); let task = tokio::spawn(async move { - debug!("Downloading: {}", filename); - let repo = api_clone.model(model_name); - let progress = FileProgressBar { - pb: pb.clone(), - total_size: total_size_clone, - }; - let result = repo.download_with_progress(&filename, progress).await; + // Check if file exists in cache + let cached_file_path = snapshot_path_clone.join(&filename); + if cached_file_path.exists() { + debug!("File {} found in 
cache, showing as complete", filename); + + // Create progress bar and mark as instantly complete + let mut file_progress = progress_manager_clone.create_file_progress(&filename); + let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0); + file_progress.init(file_size); + file_progress.update(file_size); + file_progress.finish(); - match &result { - Ok(_) => { - pb.finish(); - } - Err(_) => { - pb.abandon(); - } + return Ok(()); } - result.map_err(|e| { - DownloadError::NetworkError(format!("Failed to download {}: {}", filename, e)) - }) + // File not in cache, download with progress + debug!("Downloading: {}", filename); + let file_progress = progress_manager_clone.create_file_progress(&filename); + let progress = HfProgressAdapter { + progress: file_progress, + }; + + repo.download_with_progress(&filename, progress) + .await + .map_err(|e| { + DownloadError::NetworkError(format!( + "Failed to download {}: {}", + filename, e + )) + })?; + + Ok(()) }); tasks.push(task); @@ -157,8 +158,8 @@ impl Downloader for HuggingFaceDownloader { let elapsed_time = start_time.elapsed(); // Get accumulated size from downloads - let downloaded_size = total_size.load(Ordering::Relaxed); - let model_cache_path = cache_dir.join(format!("models--{}", name.replace("/", "--"))); + let downloaded_size = progress_manager.total_downloaded_bytes(); + let model_cache_path = cache_dir.join(format_model_name(name)); // Register the model let model_info_record = ModelInfo { diff --git a/src/downloader/mod.rs b/src/downloader/mod.rs index 2bdb3f5..39ef068 100644 --- a/src/downloader/mod.rs +++ b/src/downloader/mod.rs @@ -1,3 +1,4 @@ #[allow(clippy::module_inception)] pub mod downloader; pub mod huggingface; +pub mod progress; diff --git a/src/downloader/progress.rs b/src/downloader/progress.rs new file mode 100644 index 0000000..69f3f8d --- /dev/null +++ b/src/downloader/progress.rs @@ -0,0 +1,109 @@ +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use 
std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +/// Manages multi-file download progress tracking +/// +/// # Example +/// ```rust +/// use puma::downloader::progress::DownloadProgressManager; +/// +/// let progress_manager = DownloadProgressManager::new(30); +/// let mut file_progress = progress_manager.create_file_progress("model.bin"); +/// +/// file_progress.init(1024 * 1024); // 1 MB +/// file_progress.update(512 * 1024); // Downloaded 512 KB +/// file_progress.finish(); +/// +/// let total = progress_manager.total_downloaded_bytes(); +/// ``` +#[derive(Clone)] +pub struct DownloadProgressManager { + multi_progress: Arc, + total_size: Arc, + style: ProgressStyle, +} + +impl DownloadProgressManager { + /// Create a new progress manager with aligned file names + pub fn new(max_filename_len: usize) -> Self { + let multi_progress = Arc::new(MultiProgress::new()); + + let template = format!( + "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", + width = max_filename_len + ); + let style = ProgressStyle::default_bar() + .template(&template) + .unwrap() + .progress_chars("▇▆▅▄▃▂▁ "); + + Self { + multi_progress, + total_size: Arc::new(AtomicU64::new(0)), + style, + } + } + + /// Create a new progress bar for a file download + pub fn create_file_progress(&self, filename: &str) -> FileProgress { + let pb = self.multi_progress.add(ProgressBar::hidden()); + pb.set_style(self.style.clone()); + pb.set_message(filename.to_string()); + + FileProgress { + pb, + total_size: Arc::clone(&self.total_size), + } + } + + /// Get the total accumulated download size + pub fn total_downloaded_bytes(&self) -> u64 { + self.total_size.load(Ordering::Relaxed) + } +} + +/// Tracks progress for a single file download +#[derive(Clone)] +pub struct FileProgress { + pb: ProgressBar, + total_size: Arc, +} + +impl FileProgress { + /// Initialize progress bar with file size + pub fn init(&mut self, size: u64) { + self.pb.set_length(size); + 
self.pb.reset(); + self.pb.tick(); + self.total_size.fetch_add(size, Ordering::Relaxed); + } + + /// Update progress with downloaded bytes + pub fn update(&mut self, bytes: u64) { + self.pb.inc(bytes); + } + + /// Mark download as complete + pub fn finish(&mut self) { + self.pb.finish(); + } + + /// Mark download as failed + #[allow(dead_code)] + pub fn abandon(&mut self) { + self.pb.abandon(); + } + + /// Get the inner progress bar (for provider-specific adapters) + #[allow(dead_code)] + pub fn progress_bar(&self) -> &ProgressBar { + &self.pb + } + + /// Get the total size tracker (for provider-specific adapters) + #[allow(dead_code)] + pub fn total_size_tracker(&self) -> Arc { + Arc::clone(&self.total_size) + } +} diff --git a/src/main.rs b/src/main.rs index 02d7f23..57e4eb5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,13 +2,13 @@ mod cli; mod downloader; mod registry; mod system; -mod util; +mod utils; use clap::Parser; use tokio::runtime::Builder; use crate::cli::commands::{run, Cli}; -use crate::util::file; +use crate::utils::file; fn main() { // Initialize logger. 
diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index ccc1eab..2154d6d 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; -use crate::util::file; +use crate::utils::file; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ModelInfo { @@ -96,7 +96,7 @@ impl ModelRegistry { self.unregister_model(name)?; println!( - "\n{} {} {}", + "{} {} {}", "✓".green().bold(), "Successfully removed model".bright_white(), name.cyan().bold() diff --git a/src/system/system_info.rs b/src/system/system_info.rs index 8abb7df..b716949 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -5,8 +5,8 @@ use std::path::PathBuf; use sysinfo::System; use crate::registry::model_registry::ModelRegistry; -use crate::util::file; -use crate::util::format::format_size; +use crate::utils::file; +use crate::utils::format::format_size; #[derive(Debug, Serialize, Deserialize)] pub struct SystemInfo { diff --git a/src/util/file.rs b/src/util/file.rs deleted file mode 100644 index 602a1ef..0000000 --- a/src/util/file.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::fs; -use std::path::PathBuf; - -use dirs::home_dir; - -pub fn create_folder_if_not_exists(folder_path: &PathBuf) -> std::io::Result<()> { - fs::create_dir_all(folder_path)?; - Ok(()) -} - -pub fn root_home() -> PathBuf { - // Allow tests to override PUMA home directory - if let Ok(test_home) = std::env::var("PUMA_HOME") { - PathBuf::from(test_home) - } else { - let home = home_dir().expect("Failed to get home directory"); - home.join(".puma") - } -} - -pub fn cache_dir() -> PathBuf { - root_home().join("cache") -} - -pub fn huggingface_cache_dir() -> PathBuf { - cache_dir().join("huggingface") -} - -#[allow(dead_code)] -pub fn modelscope_cache_dir() -> PathBuf { - cache_dir().join("modelscope") -} diff --git a/src/utils/file.rs b/src/utils/file.rs new file mode 100644 index 
0000000..53ade50 --- /dev/null +++ b/src/utils/file.rs @@ -0,0 +1,95 @@ +use std::fs; +use std::path::PathBuf; + +use dirs::home_dir; + +pub fn create_folder_if_not_exists(folder_path: &PathBuf) -> std::io::Result<()> { + fs::create_dir_all(folder_path)?; + Ok(()) +} + +pub fn root_home() -> PathBuf { + // Allow tests to override PUMA home directory + if let Ok(test_home) = std::env::var("PUMA_HOME") { + PathBuf::from(test_home) + } else { + let home = home_dir().expect("Failed to get home directory"); + home.join(".puma") + } +} + +pub fn cache_dir() -> PathBuf { + root_home().join("cache") +} + +pub fn huggingface_cache_dir() -> PathBuf { + cache_dir().join("huggingface") +} + +#[allow(dead_code)] +pub fn modelscope_cache_dir() -> PathBuf { + cache_dir().join("modelscope") +} + +/// Format model name for HuggingFace cache directory +/// Converts "owner/model" to "models--owner--model" +pub fn format_model_name(name: &str) -> String { + format!("models--{}", name.replace("/", "--")) +} + +/// List all files recursively in a directory +#[allow(dead_code)] +pub fn list_files_recursive(dir: &std::path::Path) -> std::io::Result> { + let mut files = Vec::new(); + if dir.is_dir() { + for entry in fs::read_dir(dir)? 
{ + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + files.extend(list_files_recursive(&path)?); + } else { + files.push(path); + } + } + } + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_model_name_basic() { + assert_eq!(format_model_name("owner/model"), "models--owner--model"); + } + + #[test] + fn test_format_model_name_complex() { + assert_eq!( + format_model_name("Qwen/Qwen3.5-2B"), + "models--Qwen--Qwen3.5-2B" + ); + } + + #[test] + fn test_format_model_name_multiple_slashes() { + assert_eq!( + format_model_name("org/team/model"), + "models--org--team--model" + ); + } + + #[test] + fn test_format_model_name_no_slash() { + assert_eq!(format_model_name("model"), "models--model"); + } + + #[test] + fn test_format_model_name_special_chars() { + assert_eq!( + format_model_name("InftyAI/tiny-random-gpt2"), + "models--InftyAI--tiny-random-gpt2" + ); + } +} diff --git a/src/util/format.rs b/src/utils/format.rs similarity index 76% rename from src/util/format.rs rename to src/utils/format.rs index 70e109b..0e193ab 100644 --- a/src/util/format.rs +++ b/src/utils/format.rs @@ -17,6 +17,23 @@ pub fn format_size(bytes: u64) -> String { } } +/// Format byte size to human-readable format using decimal units (B, KB, MB, GB) +pub fn format_size_decimal(bytes: u64) -> String { + const KB: f64 = 1000.0; + const MB: f64 = 1000.0 * 1000.0; + const GB: f64 = 1000.0 * 1000.0 * 1000.0; + + if bytes as f64 >= GB { + format!("{:.2} GB", bytes as f64 / GB) + } else if bytes as f64 >= MB { + format!("{:.2} MB", bytes as f64 / MB) + } else if bytes as f64 >= KB { + format!("{:.2} KB", bytes as f64 / KB) + } else { + format!("{} B", bytes) + } +} + /// Format RFC3339 timestamp to human-readable relative time (e.g., "2 hours ago") pub fn format_time_ago(timestamp: &str) -> String { // Try to parse as RFC3339 @@ -213,4 +230,47 @@ mod tests { let invalid = "not-a-timestamp"; assert_eq!(format_time_ago(invalid), "not-a-timestamp"); 
} + + #[test] + fn test_format_size_decimal_bytes() { + assert_eq!(format_size_decimal(0), "0 B"); + assert_eq!(format_size_decimal(1), "1 B"); + assert_eq!(format_size_decimal(999), "999 B"); + } + + #[test] + fn test_format_size_decimal_kilobytes() { + assert_eq!(format_size_decimal(1000), "1.00 KB"); + assert_eq!(format_size_decimal(1500), "1.50 KB"); + assert_eq!(format_size_decimal(10000), "10.00 KB"); + assert_eq!(format_size_decimal(999_999), "1000.00 KB"); + } + + #[test] + fn test_format_size_decimal_megabytes() { + assert_eq!(format_size_decimal(1_000_000), "1.00 MB"); + assert_eq!(format_size_decimal(1_500_000), "1.50 MB"); + assert_eq!(format_size_decimal(10_000_000), "10.00 MB"); + assert_eq!(format_size_decimal(500_000_000), "500.00 MB"); + } + + #[test] + fn test_format_size_decimal_gigabytes() { + assert_eq!(format_size_decimal(1_000_000_000), "1.00 GB"); + assert_eq!(format_size_decimal(1_500_000_000), "1.50 GB"); + assert_eq!(format_size_decimal(10_000_000_000), "10.00 GB"); + assert_eq!(format_size_decimal(100_000_000_000), "100.00 GB"); + } + + #[test] + fn test_format_size_decimal_realistic_model_sizes() { + // Small model (100 MB) + assert_eq!(format_size_decimal(100_000_000), "100.00 MB"); + + // Medium model (7 GB) + assert_eq!(format_size_decimal(7_000_000_000), "7.00 GB"); + + // Large model (65 GB) + assert_eq!(format_size_decimal(65_000_000_000), "65.00 GB"); + } } diff --git a/src/util/mod.rs b/src/utils/mod.rs similarity index 100% rename from src/util/mod.rs rename to src/utils/mod.rs From a300a89f4e12262403809123591d480b4d5b61cd Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Tue, 21 Apr 2026 01:11:27 +0100 Subject: [PATCH 05/11] remove available mem (#22) Signed-off-by: kerthcet --- src/system/system_info.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/system/system_info.rs b/src/system/system_info.rs index b716949..d2e8d9c 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -15,7 
+15,6 @@ pub struct SystemInfo { pub architecture: String, pub cpu_cores: usize, pub total_memory: String, - pub available_memory: String, pub cache_dir: String, pub cache_size: String, pub models_count: usize, @@ -25,7 +24,7 @@ pub struct SystemInfo { impl SystemInfo { pub fn collect() -> Self { let mut sys = System::new_all(); - sys.refresh_all(); + sys.refresh_memory(); let cache_dir = file::cache_dir(); let cache_size = Self::calculate_cache_size(&cache_dir); @@ -39,7 +38,6 @@ impl SystemInfo { architecture: System::cpu_arch().unwrap_or_else(|| "Unknown".to_string()), cpu_cores: sys.cpus().len(), total_memory: format_size(sys.total_memory()), - available_memory: format_size(sys.available_memory()), cache_dir: cache_dir.to_string_lossy().to_string(), cache_size: format_size(cache_size), models_count, @@ -95,7 +93,6 @@ impl SystemInfo { println!(" Architecture: {}", self.architecture); println!(" CPU Cores: {}", self.cpu_cores); println!(" Total Memory: {}", self.total_memory); - println!(" Available Memory: {}", self.available_memory); println!(); println!("PUMA Information:"); println!(" PUMA Version: {}", self.version); From 07213c1c169624ff027f0fac12f8456e62f6fb87 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Tue, 21 Apr 2026 01:31:13 +0100 Subject: [PATCH 06/11] add speed at the end (#23) * add speed at the end Signed-off-by: kerthcet * fix lint Signed-off-by: kerthcet --------- Signed-off-by: kerthcet --- src/downloader/huggingface.rs | 5 +++-- src/downloader/progress.rs | 26 +++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index 0e16a3b..925cf7b 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -117,8 +117,9 @@ impl Downloader for HuggingFaceDownloader { if cached_file_path.exists() { debug!("File {} found in cache, showing as complete", filename); - // Create progress bar and mark as instantly complete - let mut 
file_progress = progress_manager_clone.create_file_progress(&filename); + // Create progress bar for cached file (no speed display) + let mut file_progress = + progress_manager_clone.create_cached_file_progress(&filename); let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0); file_progress.init(file_size); file_progress.update(file_size); diff --git a/src/downloader/progress.rs b/src/downloader/progress.rs index 69f3f8d..533e364 100644 --- a/src/downloader/progress.rs +++ b/src/downloader/progress.rs @@ -22,6 +22,7 @@ pub struct DownloadProgressManager { multi_progress: Arc, total_size: Arc, style: ProgressStyle, + cached_style: ProgressStyle, } impl DownloadProgressManager { @@ -30,7 +31,7 @@ impl DownloadProgressManager { let multi_progress = Arc::new(MultiProgress::new()); let template = format!( - "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", + "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}} {{bytes_per_sec}}", width = max_filename_len ); let style = ProgressStyle::default_bar() @@ -38,10 +39,21 @@ impl DownloadProgressManager { .unwrap() .progress_chars("▇▆▅▄▃▂▁ "); + // Cached file style without speed + let cached_template = format!( + "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", + width = max_filename_len + ); + let cached_style = ProgressStyle::default_bar() + .template(&cached_template) + .unwrap() + .progress_chars("▇▆▅▄▃▂▁ "); + Self { multi_progress, total_size: Arc::new(AtomicU64::new(0)), style, + cached_style, } } @@ -57,6 +69,18 @@ impl DownloadProgressManager { } } + /// Create a new progress bar for a cached file (no speed display) + pub fn create_cached_file_progress(&self, filename: &str) -> FileProgress { + let pb = self.multi_progress.add(ProgressBar::hidden()); + pb.set_style(self.cached_style.clone()); + pb.set_message(filename.to_string()); + + FileProgress { + pb, + total_size: Arc::clone(&self.total_size), 
+ } + } + /// Get the total accumulated download size pub fn total_downloaded_bytes(&self) -> u64 { self.total_size.load(Ordering::Relaxed) From dd765e3107c9e82841d0e1fef268161f9eab73a0 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Tue, 21 Apr 2026 23:24:58 +0100 Subject: [PATCH 07/11] fix: do no register model once cached (#26) Signed-off-by: kerthcet --- src/cli/commands.rs | 4 ++-- src/downloader/huggingface.rs | 18 +++++++++++++----- src/registry/model_registry.rs | 14 +++++++------- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 128d137..5615aab 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -89,7 +89,7 @@ pub async fn run(cli: Cli) { let mut table = Table::new(); table.set_format(*format::consts::FORMAT_CLEAN); - table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "CREATED"]); + table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]); for model in models { let size_str = format_size_decimal(model.size); @@ -100,7 +100,7 @@ pub async fn run(cli: Cli) { &model.revision }; - let created_str = format_time_ago(&model.created_at); + let created_str = format_time_ago(&model.modified_at); table.add_row(row![ model.name, diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index 925cf7b..1e09a32 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -99,6 +99,12 @@ impl Downloader for HuggingFaceDownloader { let sha = model_info.sha.clone(); let snapshot_path = model_cache_path.join("snapshots").join(&sha); + // Check if all files are already cached + let model_totally_cached = model_info + .siblings + .iter() + .all(|sibling| snapshot_path.join(&sibling.rfilename).exists()); + // Process all files in manifest order (cached files show as instantly complete) let mut tasks = Vec::new(); @@ -168,14 +174,16 @@ impl Downloader for HuggingFaceDownloader { provider: "huggingface".to_string(), revision: sha, size: 
downloaded_size, - created_at: chrono::Local::now().to_rfc3339(), + modified_at: chrono::Local::now().to_rfc3339(), cache_path: model_cache_path.to_string_lossy().to_string(), }; - let registry = ModelRegistry::new(None); - registry - .register_model(model_info_record) - .map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?; + if !model_totally_cached { + let registry = ModelRegistry::new(None); + registry + .register_model(model_info_record) + .map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?; + } println!( "\n{} {} {} {} {:.2?}", diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index 2154d6d..f52893a 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -11,7 +11,7 @@ pub struct ModelInfo { pub provider: String, pub revision: String, pub size: u64, - pub created_at: String, + pub modified_at: String, pub cache_path: String, } @@ -122,7 +122,7 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - created_at: "2025-01-01T00:00:00Z".to_string(), + modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), }; @@ -143,7 +143,7 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - created_at: "2025-01-01T00:00:00Z".to_string(), + modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), }; @@ -164,7 +164,7 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - created_at: "2025-01-01T00:00:00Z".to_string(), + modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), }; @@ -198,7 +198,7 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - created_at: "2025-01-01T00:00:00Z".to_string(), + modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), }; @@ -209,7 +209,7 @@ mod 
tests { provider: "huggingface".to_string(), revision: "def456".to_string(), size: 2000, - created_at: "2025-01-02T00:00:00Z".to_string(), + modified_at: "2025-01-02T00:00:00Z".to_string(), cache_path: "/tmp/test2".to_string(), }; @@ -236,7 +236,7 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - created_at: "2025-01-01T00:00:00Z".to_string(), + modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: cache_dir.to_string_lossy().to_string(), }; From 2bf6602edfefbd29addea5862e499f27309823e7 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Tue, 21 Apr 2026 23:52:15 +0100 Subject: [PATCH 08/11] Support GPU detect (#27) * support GPU detect Signed-off-by: kerthcet * fix lint Signed-off-by: kerthcet --------- Signed-off-by: kerthcet --- src/system/system_info.rs | 179 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) diff --git a/src/system/system_info.rs b/src/system/system_info.rs index d2e8d9c..9f8cc0d 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use std::fs; use std::os::unix::fs::MetadataExt; use std::path::PathBuf; +use std::process::Command; use sysinfo::System; use crate::registry::model_registry::ModelRegistry; @@ -15,12 +16,20 @@ pub struct SystemInfo { pub architecture: String, pub cpu_cores: usize, pub total_memory: String, + pub gpu_info: Vec, pub cache_dir: String, pub cache_size: String, pub models_count: usize, pub running_models: usize, } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct GpuInfo { + pub name: String, + pub backend: String, // "CUDA", "Metal", "ROCm", or "Unknown" + pub memory: Option, +} + impl SystemInfo { pub fn collect() -> Self { let mut sys = System::new_all(); @@ -32,12 +41,15 @@ impl SystemInfo { let registry = ModelRegistry::new(None); let models_count = registry.load_models().unwrap_or_default().len(); + let gpu_info = Self::detect_gpus(); + SystemInfo { version: 
env!("CARGO_PKG_VERSION").to_string(), os: System::name().unwrap_or_else(|| "Unknown".to_string()), architecture: System::cpu_arch().unwrap_or_else(|| "Unknown".to_string()), cpu_cores: sys.cpus().len(), total_memory: format_size(sys.total_memory()), + gpu_info, cache_dir: cache_dir.to_string_lossy().to_string(), cache_size: format_size(cache_size), models_count, @@ -45,6 +57,157 @@ impl SystemInfo { } } + fn detect_gpus() -> Vec { + let mut gpus = Vec::new(); + + // Try NVIDIA GPUs first (Linux/Windows) + if let Some(nvidia_gpus) = Self::detect_nvidia_gpus() { + gpus.extend(nvidia_gpus); + } + + // Try Metal (macOS) + if let Some(metal_gpu) = Self::detect_metal_gpu() { + gpus.push(metal_gpu); + } + + // Try AMD ROCm (Linux) + if let Some(amd_gpus) = Self::detect_amd_gpus() { + gpus.extend(amd_gpus); + } + + gpus + } + + fn detect_nvidia_gpus() -> Option> { + let output = Command::new("nvidia-smi") + .args([ + "--query-gpu=name,memory.total", + "--format=csv,noheader,nounits", + ]) + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let output_str = String::from_utf8(output.stdout).ok()?; + let mut gpus = Vec::new(); + + for line in output_str.lines() { + let parts: Vec<&str> = line.split(',').map(|s| s.trim()).collect(); + if parts.len() >= 2 { + gpus.push(GpuInfo { + name: parts[0].to_string(), + backend: "CUDA".to_string(), + memory: Some(format!("{} MB", parts[1])), + }); + } + } + + if gpus.is_empty() { + None + } else { + Some(gpus) + } + } + + fn detect_metal_gpu() -> Option { + // Check if running on macOS + if !cfg!(target_os = "macos") { + return None; + } + + let output = Command::new("system_profiler") + .arg("SPDisplaysDataType") + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let output_str = String::from_utf8(output.stdout).ok()?; + let lines: Vec<&str> = output_str.lines().collect(); + + // Find GPU name and cores + let mut gpu_name = None; + let mut core_count = None; + + for (i, line) in 
lines.iter().enumerate() { + if line.contains("Chipset Model:") { + let parts: Vec<&str> = line.split("Chipset Model:").collect(); + if parts.len() >= 2 { + let name = parts[1].trim(); + if !name.is_empty() { + gpu_name = Some(name.to_string()); + + // Look for core count in the next few lines + for line in lines.iter().skip(i + 1).take(10) { + if line.contains("Total Number of Cores:") { + let core_parts: Vec<&str> = + line.split("Total Number of Cores:").collect(); + if core_parts.len() >= 2 { + core_count = Some(core_parts[1].trim().to_string()); + } + break; + } + } + break; + } + } + } + } + + if let Some(name) = gpu_name { + let memory_str = core_count.map(|cores| format!("{} GPU cores", cores)); + + return Some(GpuInfo { + name, + backend: "Metal".to_string(), + memory: memory_str, + }); + } + + None + } + + fn detect_amd_gpus() -> Option> { + let output = Command::new("rocm-smi") + .arg("--showproductname") + .output() + .ok()?; + + if !output.status.success() { + return None; + } + + let output_str = String::from_utf8(output.stdout).ok()?; + let mut gpus = Vec::new(); + + for line in output_str.lines() { + if line.contains("Card series:") || line.contains("Card model:") { + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() >= 2 { + let name = parts[1].trim().to_string(); + if !name.is_empty() { + gpus.push(GpuInfo { + name, + backend: "ROCm".to_string(), + memory: None, + }); + } + } + } + } + + if gpus.is_empty() { + None + } else { + Some(gpus) + } + } + fn calculate_cache_size(cache_dir: &PathBuf) -> u64 { if !cache_dir.exists() { return 0; @@ -93,6 +256,22 @@ impl SystemInfo { println!(" Architecture: {}", self.architecture); println!(" CPU Cores: {}", self.cpu_cores); println!(" Total Memory: {}", self.total_memory); + + if !self.gpu_info.is_empty() { + for (i, gpu) in self.gpu_info.iter().enumerate() { + if i == 0 { + print!(" GPU: "); + } else { + print!(" "); + } + print!("{} ({})", gpu.name, gpu.backend); + if let Some(ref memory) 
= gpu.memory { + print!(" - {}", memory); + } + println!(); + } + } + println!(); println!("PUMA Information:"); println!(" PUMA Version: {}", self.version); From 6827d0286f0c7dce21c4aa8bfd2ec6d064ac0978 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Thu, 23 Apr 2026 00:51:32 +0100 Subject: [PATCH 09/11] update readme.md (#28) Signed-off-by: kerthcet --- README.md | 106 ++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 87 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 7b76557..86ba161 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,84 @@ # PUMA -**PUMA** aims to be a lightweight, high-performance inference engine for local AI. *Play for fun.* +A lightweight, high-performance inference engine for local AI. *Play for fun.* ## Features -- πŸš€ **Model Management** - Download and manage AI models from multiple providers +- **Model Management** - Download and manage AI models from model providers like Hugging Face +- **System Detection** - Automatic GPU detection and system information reporting +- **Local Caching** - Efficient model storage with custom cache directories +- **Multiple Providers** - Support for Hugging Face with ModelScope coming soon -## Quick Start +## Installation -### Install from source +### From Source ```bash make build ``` +The binary will be available as `./puma`. + +## Quick Start + +### 1. Download a Model + +```bash +# From Hugging Face (default) +puma pull InftyAI/tiny-random-gpt2 +``` + +### 2. List Downloaded Models + +```bash +puma ls +``` + +### 3. 
Check System Information + +```bash +puma info +``` + +Example output: +``` +System Information: + Operating System: Darwin + Architecture: arm64 + CPU Cores: 14 + Total Memory: 36.00 GiB + GPU: Apple M4 Max (Metal) - 32 GPU cores + +PUMA Information: + PUMA Version: 0.0.1 + Cache Directory: ~/.puma/cache + Cache Size: 799.88 MiB + Models: 1 + Running Models: 0 +``` + ## Commands -| Command | Description | -|---------|-------------| -| `pull` | Download a model from a provider | -| `ls` | List local models | -| `ps` | List running models | -| `run` | Create and run a model | -| `stop` | Stop a running model | -| `rm` | Remove a model | -| `info` | Display system-wide information | -| `inspect` | Return detailed information about a model | -| `version` | Show PUMA version | -| `help` | Show help information | +| Command | Status | Description | Example | +|---------|--------|-------------|---------| +| `pull` | βœ… | Download a model from a provider | `puma pull InftyAI/tiny-random-gpt2` | +| `ls` | βœ… | List local models | `puma ls` | +| `ps` | 🚧 | List running models | `puma ps` | +| `run` | 🚧 | Create and run a model | `puma run InftyAI/tiny-random-gpt2` | +| `stop` | 🚧 | Stop a running model | `puma stop ` | +| `rm` | βœ… | Remove a model | `puma rm InftyAI/tiny-random-gpt2` | +| `info` | βœ… | Display system-wide information | `puma info` | +| `inspect` | 🚧 | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` | +| `version` | βœ… | Show PUMA version | `puma version` | +| `help` | βœ… | Show help information | `puma help` | + +## Configuration + +PUMA stores models in `~/.puma/cache` by default. This location is used for all downloaded models and metadata. 
+ +## Supported Providers + +- **Hugging Face** - Full support with custom cache directories ## Development @@ -43,7 +94,24 @@ make build make test ``` -### Supported Providers +### Project Structure + +``` +puma/ +β”œβ”€β”€ src/ +β”‚ β”œβ”€β”€ cli/ # Command-line interface +β”‚ β”œβ”€β”€ downloader/ # Model download logic +β”‚ β”œβ”€β”€ registry/ # Model registry management +β”‚ β”œβ”€β”€ system/ # System detection (CPU, GPU, memory) +β”‚ └── utils/ # Utility functions +β”œβ”€β”€ Cargo.toml # Rust dependencies +└── Makefile # Build commands +``` + +## License + +Apache-2.0 + +## Contributing -- βœ… **Hugging Face** - Full support with custom cache directories -- 🚧 **ModelScope** - Coming soon +[![Star History Chart](https://api.star-history.com/svg?repos=inftyai/puma&type=Date)](https://www.star-history.com/#inftyai/puma&Date) From 6ffd9fbf3c267e5a697d95d8bb3bef509b9ce5f4 Mon Sep 17 00:00:00 2001 From: Kante Yin Date: Fri, 24 Apr 2026 00:54:17 +0100 Subject: [PATCH 10/11] Support inspect command (#29) * add support for inspect Signed-off-by: kerthcet * add support for inspect Signed-off-by: kerthcet * add pull progress bar Signed-off-by: kerthcet * polish the download progress Signed-off-by: kerthcet * reorganize the structure Signed-off-by: kerthcet * optimize the structure Signed-off-by: kerthcet * fix test Signed-off-by: kerthcet * fix lint Signed-off-by: kerthcet --------- Signed-off-by: kerthcet --- README.md | 2 +- src/cli/commands.rs | 69 +++++++++- src/downloader/huggingface.rs | 75 ++++++++--- src/downloader/progress.rs | 13 ++ src/registry/model_registry.rs | 229 +++++++++++++++++++++++++++++++++ src/system/system_info.rs | 1 - src/utils/format.rs | 77 +++++++++++ 7 files changed, 444 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 86ba161..e2ed14c 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ PUMA Information: | `stop` | 🚧 | Stop a running model | `puma stop ` | | `rm` | βœ… | Remove a model | `puma rm 
InftyAI/tiny-random-gpt2` | | `info` | βœ… | Display system-wide information | `puma info` | -| `inspect` | 🚧 | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` | +| `inspect` | βœ… | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` | | `version` | βœ… | Show PUMA version | `puma version` | | `help` | βœ… | Show help information | `puma help` | diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 5615aab..8535b8b 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -33,7 +33,7 @@ enum Commands { /// Display system-wide information INFO, /// Return detailed information about a model - INSPECT, + INSPECT(InspectArgs), /// Returns the version of PUMA. VERSION, } @@ -58,6 +58,12 @@ struct RmArgs { model: String, } +#[derive(Parser)] +struct InspectArgs { + /// Model name to inspect (e.g., InftyAI/tiny-random-gpt2) + model: String, +} + #[derive(Debug, Clone, Default, clap::ValueEnum)] pub enum Provider { #[default] @@ -70,7 +76,12 @@ pub async fn run(cli: Cli) { match cli.command { Commands::PS => { let mut table = Table::new(); - table.set_format(*format::consts::FORMAT_CLEAN); + table.set_format( + format::FormatBuilder::new() + .column_separator(' ') + .padding(0, 1) + .build(), + ); table.add_row(row!["NAME", "PROVIDER", "MODEL", "STATUS", "AGE"]); table.add_row(row![ "deepseek-r1", @@ -88,7 +99,12 @@ pub async fn run(cli: Cli) { let models = registry.load_models().unwrap_or_default(); let mut table = Table::new(); - table.set_format(*format::consts::FORMAT_CLEAN); + table.set_format( + format::FormatBuilder::new() + .column_separator(' ') + .padding(0, 1) + .build(), + ); table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]); for model in models { @@ -163,8 +179,51 @@ pub async fn run(cli: Cli) { info.display(); } - Commands::INSPECT => { - println!("Returning detailed information about model..."); + Commands::INSPECT(args) => { + let 
registry = ModelRegistry::new(None); + + match registry.get_model(&args.model) { + Ok(Some(model)) => { + println!("Name: {}", model.name); + println!("Kind: Model"); + + println!("Spec:"); + // Architecture section (only if info is available) + if let Some(arch) = &model.arch { + println!(" Architecture:"); + if let Some(model_type) = &arch.model_type { + println!(" Type: {}", model_type); + } + if let Some(classes) = &arch.classes { + println!(" Classes: {}", classes.join(", ")); + } + if let Some(parameters) = &arch.parameters { + println!(" Parameters: {}", parameters); + } + if let Some(context_window) = arch.context_window { + println!(" Context Window: {}", context_window); + } + } + // Registry section + println!(" Registry:"); + println!(" Provider: {}", model.provider); + println!(" Revision: {}", model.revision); + println!(" Size: {}", format_size_decimal(model.size)); + println!( + " Modified: {}", + format_time_ago(&model.modified_at) + ); + println!(" Cache Path: {}", model.cache_path); + } + Ok(None) => { + eprintln!("Model not found: {}", args.model); + std::process::exit(1); + } + Err(e) => { + eprintln!("Failed to load registry: {}", e); + std::process::exit(1); + } + } } Commands::VERSION => { diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index 1e09a32..2b44810 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -2,10 +2,11 @@ use colored::Colorize; use log::debug; use hf_hub::api::tokio::{ApiBuilder, Progress}; +use indicatif::{ProgressBar, ProgressStyle}; use crate::downloader::downloader::{DownloadError, Downloader}; use crate::downloader::progress::{DownloadProgressManager, FileProgress}; -use crate::registry::model_registry::{ModelInfo, ModelRegistry}; +use crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry}; use crate::utils::file::{self, format_model_name}; /// Adapter to bridge HuggingFace's Progress trait with our FileProgress @@ -62,7 +63,15 @@ impl 
Downloader for HuggingFaceDownloader { DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e)) })?; - println!("πŸ† pulling manifest"); + // Create a simple spinner for manifest pulling + let manifest_spinner = ProgressBar::new_spinner(); + manifest_spinner.set_style( + ProgressStyle::default_spinner() + .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏") + .template("pulling manifest {spinner:.white}") + .unwrap(), + ); + manifest_spinner.enable_steady_tick(std::time::Duration::from_millis(80)); // Download the entire model repository using snapshot download let repo = api.model(name.to_string()); @@ -81,6 +90,10 @@ impl Downloader for HuggingFaceDownloader { } })?; + // Stop manifest spinner and print clean message + manifest_spinner.finish_and_clear(); + println!("pulling manifest"); + debug!("Model info for {}: {:?}", name, model_info); // Calculate the longest filename for proper alignment @@ -91,6 +104,8 @@ impl Downloader for HuggingFaceDownloader { .max() .unwrap_or(30); + // Add extra space for "pulling " prefix + let max_filename_len = max_filename_len + 8; // Create progress manager let progress_manager = DownloadProgressManager::new(max_filename_len); @@ -124,8 +139,9 @@ impl Downloader for HuggingFaceDownloader { debug!("File {} found in cache, showing as complete", filename); // Create progress bar for cached file (no speed display) + let display_name = format!("pulling {}", filename); let mut file_progress = - progress_manager_clone.create_cached_file_progress(&filename); + progress_manager_clone.create_cached_file_progress(&display_name); let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0); file_progress.init(file_size); file_progress.update(file_size); @@ -136,7 +152,8 @@ impl Downloader for HuggingFaceDownloader { // File not in cache, download with progress debug!("Downloading: {}", filename); - let file_progress = progress_manager_clone.create_file_progress(&filename); + let display_name = format!("pulling {}", filename); + 
let file_progress = progress_manager_clone.create_file_progress(&display_name); let progress = HfProgressAdapter { progress: file_progress, }; @@ -156,37 +173,65 @@ impl Downloader for HuggingFaceDownloader { tasks.push(task); } + // Give tasks a moment to start and create their progress bars + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Show spinner at the bottom after all progress bars are created (only if not fully cached) + let spinner = if !model_totally_cached { + Some(progress_manager.create_spinner()) + } else { + None + }; + // Wait for all downloads to complete for task in tasks { task.await .map_err(|e| DownloadError::ApiError(format!("Task join error: {}", e)))??; } + // Finish spinner after downloads complete + if let Some(spinner) = &spinner { + spinner.finish_and_clear(); + } + let elapsed_time = start_time.elapsed(); // Get accumulated size from downloads let downloaded_size = progress_manager.total_downloaded_bytes(); let model_cache_path = cache_dir.join(format_model_name(name)); - // Register the model - let model_info_record = ModelInfo { - name: name.to_string(), - provider: "huggingface".to_string(), - revision: sha, - size: downloaded_size, - modified_at: chrono::Local::now().to_rfc3339(), - cache_path: model_cache_path.to_string_lossy().to_string(), - }; - + // Register the model only if not totally cached if !model_totally_cached { + // Extract architecture info from config.json + let config_path = snapshot_path.join("config.json"); + let arch = if config_path.exists() { + std::fs::read_to_string(&config_path) + .ok() + .and_then(|content| serde_json::from_str::(&content).ok()) + .and_then(|config| ModelArchitecture::from_config(&config)) + } else { + None + }; + + let model_info_record = ModelInfo { + name: name.to_string(), + provider: "huggingface".to_string(), + revision: sha, + size: downloaded_size, + modified_at: chrono::Local::now().to_rfc3339(), + cache_path: 
model_cache_path.to_string_lossy().to_string(), + arch, + }; + let registry = ModelRegistry::new(None); registry .register_model(model_info_record) .map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?; } + // Print success message println!( - "\n{} {} {} {} {:.2?}", + "{} {} {} {} {:.2?}", "βœ“".green().bold(), "Successfully downloaded model".bright_white(), name.cyan().bold(), diff --git a/src/downloader/progress.rs b/src/downloader/progress.rs index 533e364..7b3ba32 100644 --- a/src/downloader/progress.rs +++ b/src/downloader/progress.rs @@ -85,6 +85,19 @@ impl DownloadProgressManager { pub fn total_downloaded_bytes(&self) -> u64 { self.total_size.load(Ordering::Relaxed) } + + /// Create a spinner progress bar (for post-download operations) + pub fn create_spinner(&self) -> ProgressBar { + let pb = self.multi_progress.add(ProgressBar::new_spinner()); + pb.set_style( + ProgressStyle::default_spinner() + .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏") + .template("{spinner} ") + .unwrap(), + ); + pb.enable_steady_tick(std::time::Duration::from_millis(80)); + pb + } } /// Tracks progress for a single file download diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index f52893a..13ea87c 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -4,6 +4,93 @@ use std::fs; use std::path::PathBuf; use crate::utils::file; +use crate::utils::format::format_parameters; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub struct ModelArchitecture { + #[serde(skip_serializing_if = "Option::is_none")] + pub model_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub classes: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub context_window: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, +} + +impl ModelArchitecture { + /// Extract model architecture from config.json + pub fn from_config(config: &serde_json::Value) -> Option { + let 
model_type = config + .get("model_type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let classes = config + .get("architectures") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect::>() + }) + .filter(|v| !v.is_empty()); + + let context_window = config + .get("n_positions") + .or_else(|| config.get("max_position_embeddings")) + .or_else(|| config.get("n_ctx")) + .and_then(|v| v.as_u64()) + .map(|v| v as u32); + + let parameters = Self::estimate_parameters(config); + + if model_type.is_some() + || classes.is_some() + || context_window.is_some() + || parameters.is_some() + { + Some(ModelArchitecture { + model_type, + classes, + context_window, + parameters, + }) + } else { + None + } + } + + /// Estimate model parameters from config + fn estimate_parameters(config: &serde_json::Value) -> Option { + let n_layer = config + .get("n_layer") + .or_else(|| config.get("num_hidden_layers")) + .and_then(|v| v.as_u64())?; + + let n_embd = config + .get("n_embd") + .or_else(|| config.get("hidden_size")) + .and_then(|v| v.as_u64())?; + + let vocab_size = config.get("vocab_size").and_then(|v| v.as_u64())?; + + let n_positions = config + .get("n_positions") + .or_else(|| config.get("max_position_embeddings")) + .and_then(|v| v.as_u64()) + .unwrap_or(2048); + + // Rough parameter estimation for transformer models + // Each layer: ~12 * n_embd^2 (attention + FFN) + // Embeddings: vocab_size * n_embd + n_positions * n_embd + let layer_params = 12 * n_layer * n_embd * n_embd; + let embedding_params = vocab_size * n_embd + n_positions * n_embd; + let total_params = layer_params + embedding_params; + + Some(format_parameters(total_params)) + } +} #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ModelInfo { @@ -13,6 +100,8 @@ pub struct ModelInfo { pub size: u64, pub modified_at: String, pub cache_path: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub arch: Option, } pub 
struct ModelRegistry { @@ -124,6 +213,7 @@ mod tests { size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model.clone()).unwrap(); @@ -145,6 +235,7 @@ mod tests { size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model).unwrap(); @@ -166,6 +257,7 @@ mod tests { size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model).unwrap(); @@ -200,6 +292,7 @@ mod tests { size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), + arch: None, }; registry.register_model(model1).unwrap(); @@ -211,6 +304,7 @@ mod tests { size: 2000, modified_at: "2025-01-02T00:00:00Z".to_string(), cache_path: "/tmp/test2".to_string(), + arch: None, }; registry.register_model(model2).unwrap(); @@ -238,6 +332,7 @@ mod tests { size: 1000, modified_at: "2025-01-01T00:00:00Z".to_string(), cache_path: cache_dir.to_string_lossy().to_string(), + arch: None, }; registry.register_model(model).unwrap(); @@ -263,4 +358,138 @@ mod tests { let result = registry.remove_model("nonexistent"); assert!(result.is_ok()); } + + #[test] + fn test_inspect_model_with_full_spec() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/gpt-model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123def456".to_string(), + size: 7_000_000_000, + modified_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test/gpt".to_string(), + arch: Some(ModelArchitecture { + model_type: Some("gpt2".to_string()), + classes: Some(vec!["GPT2LMHeadModel".to_string()]), + context_window: Some(2048), + parameters: Some("7.00B".to_string()), + }), + }; + + registry.register_model(model).unwrap(); + + let retrieved = 
registry.get_model("test/gpt-model").unwrap(); + assert!(retrieved.is_some()); + + let model_info = retrieved.unwrap(); + assert_eq!(model_info.name, "test/gpt-model"); + assert_eq!(model_info.provider, "huggingface"); + assert_eq!(model_info.revision, "abc123def456"); + assert_eq!(model_info.size, 7_000_000_000); + + let arch = model_info.arch.unwrap(); + assert_eq!(arch.model_type, Some("gpt2".to_string())); + assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); + assert_eq!(arch.context_window, Some(2048)); + assert_eq!(arch.parameters, Some("7.00B".to_string())); + } + + #[test] + fn test_inspect_model_without_spec() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/legacy-model".to_string(), + provider: "huggingface".to_string(), + revision: "legacy123".to_string(), + size: 1_000_000, + modified_at: "2024-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test/legacy".to_string(), + arch: None, + }; + + registry.register_model(model).unwrap(); + + let retrieved = registry.get_model("test/legacy-model").unwrap(); + assert!(retrieved.is_some()); + + let model_info = retrieved.unwrap(); + assert_eq!(model_info.name, "test/legacy-model"); + assert!(model_info.arch.is_none()); + } + + #[test] + fn test_model_architecture_from_config_gpt2() { + use serde_json::json; + + let config = json!({ + "model_type": "gpt2", + "architectures": ["GPT2LMHeadModel"], + "n_layer": 5, + "n_embd": 32, + "vocab_size": 1000, + "n_positions": 512 + }); + + let arch = ModelArchitecture::from_config(&config); + assert!(arch.is_some()); + + let arch = arch.unwrap(); + assert_eq!(arch.model_type, Some("gpt2".to_string())); + assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); + assert_eq!(arch.context_window, Some(512)); + assert_eq!(arch.parameters, Some("109.82K".to_string())); + } + + #[test] + fn test_model_architecture_from_config_bert_style() { 
+ use serde_json::json; + + let config = json!({ + "model_type": "bert", + "num_hidden_layers": 12, + "hidden_size": 768, + "vocab_size": 30000, + "max_position_embeddings": 512 + }); + + let arch = ModelArchitecture::from_config(&config); + assert!(arch.is_some()); + + let arch = arch.unwrap(); + assert_eq!(arch.model_type, Some("bert".to_string())); + assert_eq!(arch.context_window, Some(512)); + assert!(arch.parameters.unwrap().contains("M")); + } + + #[test] + fn test_model_architecture_from_config_partial() { + use serde_json::json; + + let config = json!({ + "model_type": "llama", + "n_ctx": 4096 + }); + + let arch = ModelArchitecture::from_config(&config); + assert!(arch.is_some()); + + let arch = arch.unwrap(); + assert_eq!(arch.model_type, Some("llama".to_string())); + assert_eq!(arch.context_window, Some(4096)); + assert_eq!(arch.parameters, None); + } + + #[test] + fn test_model_architecture_from_config_empty() { + use serde_json::json; + + let config = json!({}); + let arch = ModelArchitecture::from_config(&config); + assert_eq!(arch, None); + } } diff --git a/src/system/system_info.rs b/src/system/system_info.rs index 9f8cc0d..00b49cd 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -272,7 +272,6 @@ impl SystemInfo { } } - println!(); println!("PUMA Information:"); println!(" PUMA Version: {}", self.version); println!(" Cache Directory: {}", self.cache_dir); diff --git a/src/utils/format.rs b/src/utils/format.rs index 0e193ab..865c9b9 100644 --- a/src/utils/format.rs +++ b/src/utils/format.rs @@ -34,6 +34,23 @@ pub fn format_size_decimal(bytes: u64) -> String { } } +/// Format parameter count to human-readable format (K, M, B) +pub fn format_parameters(count: u64) -> String { + const K: f64 = 1_000.0; + const M: f64 = 1_000_000.0; + const B: f64 = 1_000_000_000.0; + + if count as f64 >= B { + format!("{:.2}B", count as f64 / B) + } else if count as f64 >= M { + format!("{:.2}M", count as f64 / M) + } else if count as f64 >= K { 
+ format!("{:.2}K", count as f64 / K) + } else { + count.to_string() + } +} + /// Format RFC3339 timestamp to human-readable relative time (e.g., "2 hours ago") pub fn format_time_ago(timestamp: &str) -> String { // Try to parse as RFC3339 @@ -273,4 +290,64 @@ mod tests { // Large model (65 GB) assert_eq!(format_size_decimal(65_000_000_000), "65.00 GB"); } + + #[test] + fn test_format_parameters_raw() { + assert_eq!(format_parameters(0), "0"); + assert_eq!(format_parameters(1), "1"); + assert_eq!(format_parameters(999), "999"); + } + + #[test] + fn test_format_parameters_thousands() { + assert_eq!(format_parameters(1_000), "1.00K"); + assert_eq!(format_parameters(1_500), "1.50K"); + assert_eq!(format_parameters(10_000), "10.00K"); + assert_eq!(format_parameters(999_999), "1000.00K"); + } + + #[test] + fn test_format_parameters_millions() { + assert_eq!(format_parameters(1_000_000), "1.00M"); + assert_eq!(format_parameters(1_500_000), "1.50M"); + assert_eq!(format_parameters(7_000_000), "7.00M"); + assert_eq!(format_parameters(350_000_000), "350.00M"); + } + + #[test] + fn test_format_parameters_billions() { + assert_eq!(format_parameters(1_000_000_000), "1.00B"); + assert_eq!(format_parameters(1_500_000_000), "1.50B"); + assert_eq!(format_parameters(7_000_000_000), "7.00B"); + assert_eq!(format_parameters(175_000_000_000), "175.00B"); + } + + #[test] + fn test_format_parameters_realistic_models() { + // Tiny model (109K parameters) + assert_eq!(format_parameters(109_824), "109.82K"); + + // Small model (125M parameters) + assert_eq!(format_parameters(125_000_000), "125.00M"); + + // Medium model (7B parameters) + assert_eq!(format_parameters(7_000_000_000), "7.00B"); + + // Large model (70B parameters) + assert_eq!(format_parameters(70_000_000_000), "70.00B"); + + // Very large model (405B parameters) + assert_eq!(format_parameters(405_000_000_000), "405.00B"); + } + + #[test] + fn test_format_parameters_edge_cases() { + // Boundary between K and M + 
assert_eq!(format_parameters(999_999), "1000.00K"); + assert_eq!(format_parameters(1_000_000), "1.00M"); + + // Boundary between M and B + assert_eq!(format_parameters(999_999_999), "1000.00M"); + assert_eq!(format_parameters(1_000_000_000), "1.00B"); + } } From 4f146d7e54dd6fe2505678dd27de9656cf040e9a Mon Sep 17 00:00:00 2001 From: kerthcet Date: Fri, 24 Apr 2026 01:17:32 +0100 Subject: [PATCH 11/11] add metadata Signed-off-by: kerthcet --- src/cli/commands.rs | 212 +++++++++++++++++++++++++++++++-- src/downloader/huggingface.rs | 4 +- src/registry/model_registry.rs | 45 +++++-- 3 files changed, 243 insertions(+), 18 deletions(-) diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 8535b8b..085594a 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -105,7 +105,7 @@ pub async fn run(cli: Cli) { .padding(0, 1) .build(), ); - table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]); + table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "AGE"]); for model in models { let size_str = format_size_decimal(model.size); @@ -116,7 +116,7 @@ pub async fn run(cli: Cli) { &model.revision }; - let created_str = format_time_ago(&model.modified_at); + let created_str = format_time_ago(&model.created_at); table.add_row(row![ model.name, @@ -186,7 +186,9 @@ pub async fn run(cli: Cli) { Ok(Some(model)) => { println!("Name: {}", model.name); println!("Kind: Model"); - + println!("Metadata:"); + println!(" Created: {}", format_time_ago(&model.created_at)); + println!(" Updated: {}", format_time_ago(&model.updated_at)); println!("Spec:"); // Architecture section (only if info is available) if let Some(arch) = &model.arch { @@ -209,10 +211,6 @@ pub async fn run(cli: Cli) { println!(" Provider: {}", model.provider); println!(" Revision: {}", model.revision); println!(" Size: {}", format_size_decimal(model.size)); - println!( - " Modified: {}", - format_time_ago(&model.modified_at) - ); println!(" Cache Path: {}", model.cache_path); } Ok(None) => 
{ @@ -231,3 +229,203 @@ pub async fn run(cli: Cli) { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::model_registry::{ModelArchitecture, ModelInfo}; + use tempfile::TempDir; + + #[test] + fn test_ls_command_empty() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let models = registry.load_models().unwrap_or_default(); + assert_eq!(models.len(), 0); + } + + #[test] + fn test_ls_command_with_models() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123def456".to_string(), + size: 1_000_000, + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + arch: None, + }; + + registry.register_model(model).unwrap(); + + let models = registry.load_models().unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "test/model"); + assert_eq!(models[0].provider, "huggingface"); + } + + #[test] + fn test_inspect_command_with_metadata() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/gpt-model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123def456".to_string(), + size: 7_000_000_000, + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-02T00:00:00Z".to_string(), + cache_path: "/tmp/test/gpt".to_string(), + arch: Some(ModelArchitecture { + model_type: Some("gpt2".to_string()), + classes: Some(vec!["GPT2LMHeadModel".to_string()]), + context_window: Some(2048), + parameters: Some("7.00B".to_string()), + }), + }; + + registry.register_model(model.clone()).unwrap(); + + let retrieved = registry.get_model("test/gpt-model").unwrap(); + 
assert!(retrieved.is_some()); + + let model_info = retrieved.unwrap(); + assert_eq!(model_info.name, "test/gpt-model"); + assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z"); + assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z"); + + let arch = model_info.arch.unwrap(); + assert_eq!(arch.model_type, Some("gpt2".to_string())); + assert_eq!(arch.classes, Some(vec!["GPT2LMHeadModel".to_string()])); + assert_eq!(arch.context_window, Some(2048)); + assert_eq!(arch.parameters, Some("7.00B".to_string())); + } + + #[test] + fn test_inspect_command_without_architecture() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/simple-model".to_string(), + provider: "huggingface".to_string(), + revision: "xyz789".to_string(), + size: 500_000, + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test/simple".to_string(), + arch: None, + }; + + registry.register_model(model).unwrap(); + + let retrieved = registry.get_model("test/simple-model").unwrap(); + assert!(retrieved.is_some()); + + let model_info = retrieved.unwrap(); + assert_eq!(model_info.name, "test/simple-model"); + assert!(model_info.arch.is_none()); + } + + #[test] + fn test_rm_command() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/remove-model".to_string(), + provider: "huggingface".to_string(), + revision: "abc123".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test/remove".to_string(), + arch: None, + }; + + registry.register_model(model).unwrap(); + assert!(registry.get_model("test/remove-model").unwrap().is_some()); + + // Simulate RM command + let result = registry.get_model("test/remove-model"); + 
assert!(result.is_ok()); + assert!(result.unwrap().is_some()); + } + + #[test] + fn test_rm_command_nonexistent() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let result = registry.get_model("nonexistent/model"); + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + } + + #[test] + fn test_revision_truncation() { + let long_revision = "abc123def456ghi789jkl012"; + let short = if long_revision.len() > 8 { + &long_revision[..8] + } else { + long_revision + }; + assert_eq!(short, "abc123de"); + + let short_revision = "abc123"; + let short = if short_revision.len() > 8 { + &short_revision[..8] + } else { + short_revision + }; + assert_eq!(short, "abc123"); + } + + #[test] + fn test_metadata_timestamps_differ() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = ModelInfo { + name: "test/updated-model".to_string(), + provider: "huggingface".to_string(), + revision: "v1".to_string(), + size: 1000, + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + arch: None, + }; + + registry.register_model(model).unwrap(); + + // Update the model + let updated_model = ModelInfo { + name: "test/updated-model".to_string(), + provider: "huggingface".to_string(), + revision: "v2".to_string(), + size: 2000, + created_at: "2025-01-05T00:00:00Z".to_string(), + updated_at: "2025-01-05T00:00:00Z".to_string(), + cache_path: "/tmp/test".to_string(), + arch: None, + }; + + registry.register_model(updated_model).unwrap(); + + let result = registry.get_model("test/updated-model").unwrap().unwrap(); + // created_at should remain the same + assert_eq!(result.created_at, "2025-01-01T00:00:00Z"); + // updated_at should be new + assert_eq!(result.updated_at, "2025-01-05T00:00:00Z"); + // Other fields should be updated + assert_eq!(result.revision, 
"v2"); + assert_eq!(result.size, 2000); + } +} diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index 2b44810..dfc0a40 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -213,12 +213,14 @@ impl Downloader for HuggingFaceDownloader { None }; + let now = chrono::Local::now().to_rfc3339(); let model_info_record = ModelInfo { name: name.to_string(), provider: "huggingface".to_string(), revision: sha, size: downloaded_size, - modified_at: chrono::Local::now().to_rfc3339(), + created_at: now.clone(), + updated_at: now, cache_path: model_cache_path.to_string_lossy().to_string(), arch, }; diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index 13ea87c..3505570 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -98,7 +98,8 @@ pub struct ModelInfo { pub provider: String, pub revision: String, pub size: u64, - pub modified_at: String, + pub created_at: String, + pub updated_at: String, pub cache_path: String, #[serde(skip_serializing_if = "Option::is_none")] pub arch: Option, @@ -148,10 +149,22 @@ impl ModelRegistry { pub fn register_model(&self, model: ModelInfo) -> Result<(), std::io::Error> { let mut models = self.load_models()?; + // Check if model already exists to preserve created_at + let existing_created_at = models + .iter() + .find(|m| m.name == model.name) + .map(|m| m.created_at.clone()); + // Remove existing model with same name if exists models.retain(|m| m.name != model.name); - models.push(model); + // Use existing created_at if this is an update, otherwise use the provided one + let mut final_model = model; + if let Some(created_at) = existing_created_at { + final_model.created_at = created_at; + } + + models.push(final_model); self.save_models(&models)?; Ok(()) @@ -211,7 +224,8 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - modified_at: "2025-01-01T00:00:00Z".to_string(), + created_at: 
"2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), arch: None, }; @@ -233,7 +247,8 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - modified_at: "2025-01-01T00:00:00Z".to_string(), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), arch: None, }; @@ -255,7 +270,8 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - modified_at: "2025-01-01T00:00:00Z".to_string(), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), arch: None, }; @@ -290,7 +306,8 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - modified_at: "2025-01-01T00:00:00Z".to_string(), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test".to_string(), arch: None, }; @@ -302,7 +319,8 @@ mod tests { provider: "huggingface".to_string(), revision: "def456".to_string(), size: 2000, - modified_at: "2025-01-02T00:00:00Z".to_string(), + created_at: "2025-01-02T00:00:00Z".to_string(), + updated_at: "2025-01-02T00:00:00Z".to_string(), cache_path: "/tmp/test2".to_string(), arch: None, }; @@ -313,6 +331,10 @@ mod tests { assert_eq!(models.len(), 1); assert_eq!(models[0].revision, "def456"); assert_eq!(models[0].size, 2000); + // created_at should be preserved from model1 + assert_eq!(models[0].created_at, "2025-01-01T00:00:00Z"); + // updated_at should be from model2 + assert_eq!(models[0].updated_at, "2025-01-02T00:00:00Z"); } #[test] @@ -330,7 +352,8 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123".to_string(), size: 1000, - modified_at: "2025-01-01T00:00:00Z".to_string(), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: 
"2025-01-01T00:00:00Z".to_string(), cache_path: cache_dir.to_string_lossy().to_string(), arch: None, }; @@ -369,7 +392,8 @@ mod tests { provider: "huggingface".to_string(), revision: "abc123def456".to_string(), size: 7_000_000_000, - modified_at: "2025-01-01T00:00:00Z".to_string(), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test/gpt".to_string(), arch: Some(ModelArchitecture { model_type: Some("gpt2".to_string()), @@ -407,7 +431,8 @@ mod tests { provider: "huggingface".to_string(), revision: "legacy123".to_string(), size: 1_000_000, - modified_at: "2024-01-01T00:00:00Z".to_string(), + created_at: "2024-01-01T00:00:00Z".to_string(), + updated_at: "2024-01-01T00:00:00Z".to_string(), cache_path: "/tmp/test/legacy".to_string(), arch: None, };