diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 3568ad34..f4502a0c 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -148,7 +148,9 @@ SbsWriter::SbsWriter(const std::string& path) : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {} SbsReader::SbsReader(const std::string& path) - : reader_(Path(path)), model_(reader_) {} + : reader_(Path(path)), + model_(reader_, Path(), /*wrapping=*/Tristate::kDefault, + /*allow_unknown_config=*/Tristate::kTrue) {} } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 76f0c754..0adee85e 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -23,6 +23,7 @@ #include #include #include // strcmp +#include #include #include // std::errc // NOLINT @@ -250,7 +251,8 @@ static int DeduceLayerTypes(const BlobReader& reader) { // `wrapping_override` is forwarded from the command line. For pre-2025 files // without `ModelConfig`, it is the only way to force PT. static ModelConfig ReadOrDeduceConfig(BlobReader& reader, - Tristate wrapping_override) { + Tristate wrapping_override, + Tristate allow_unknown_config) { const TypePrefix type_prefix(reader.Keys(), reader); Type deduced_weight = Type::kUnknown; if (type_prefix.HasPrefixes()) { @@ -258,15 +260,18 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader, type_prefix.PrintTypeBytes(); } - // Always deduce so we can verify it against the config we read. - const size_t layers = DeduceNumLayers(reader.Keys()); - const int layer_types = DeduceLayerTypes(reader); - const Model deduced_model = - DeduceModel(reader.blob_path(), layers, layer_types); + // Optionally deduce so we can verify it against the config we read. + std::optional deduced_model; + const bool has_config = reader.Find(kConfigName); + if (!has_config || allow_unknown_config != Tristate::kTrue) { + const size_t layers = DeduceNumLayers(reader.Keys()); + const int layer_types = DeduceLayerTypes(reader); + deduced_model = DeduceModel(reader.blob_path(), layers, layer_types); + } ModelConfig config; // Check first to prevent `CallWithSpan` from printing a warning. - if (reader.Find(kConfigName)) { + if (has_config) { HWY_ASSERT(reader.CallWithSpan( kConfigName, [&config](const SerializedSpan serialized) { const IFields::ReadResult result = config.Read(serialized, 0); @@ -274,7 +279,6 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader, HWY_ASSERT_M(result.pos != 0, "Error deserializing config"); })); - HWY_ASSERT(config.model != Model::UNKNOWN); HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel); HWY_ASSERT(config.weight != Type::kUnknown); for (const LayerConfig& layer_config : config.layer_configs) { @@ -285,18 +289,19 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader, // We trust the deserialized config, but checking helps to validate the // deduction, which we rely on below for pre-2025 files. - if (config.model != deduced_model) { + if (deduced_model.has_value() && config.model != *deduced_model) { const std::string suffix = WrappingSuffix(config.wrapping); HWY_WARN("Detected model %s does not match config %s.", - (std::string(ModelPrefix(deduced_model)) + suffix).c_str(), + (std::string(ModelPrefix(*deduced_model)) + suffix).c_str(), (std::string(ModelPrefix(config.model)) + suffix).c_str()); } return config; } // Pre-2025 format: no config, rely on deduction plus `wrapping_override`. - return ModelConfig(deduced_model, deduced_weight, - ChooseWrapping(deduced_model, wrapping_override)); + HWY_ASSERT(deduced_model.has_value()); + return ModelConfig(*deduced_model, deduced_weight, + ChooseWrapping(*deduced_model, wrapping_override)); } static std::vector ReadScales(BlobReader& reader, @@ -387,8 +392,8 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) { } ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path, - Tristate wrapping) - : config_(ReadOrDeduceConfig(reader, wrapping)), + Tristate wrapping, Tristate allow_unknown_config) + : config_(ReadOrDeduceConfig(reader, wrapping, allow_unknown_config)), tokenizer_(ReadTokenizer(reader, tokenizer_path)) { if (!ReadMatPtrs(reader)) { // Pre-2025 format. CreateMatPtrs(reader); @@ -460,7 +465,6 @@ static void AddBlob(const char* name, const std::vector& data, void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer, const std::vector& serialized_mat_ptrs, BlobWriter& writer) { - HWY_ASSERT(config.model != Model::UNKNOWN); HWY_ASSERT(config.weight != Type::kUnknown); HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel); const std::vector serialized_config = config.Write(); diff --git a/gemma/model_store.h b/gemma/model_store.h index 506fb77c..76be557b 100644 --- a/gemma/model_store.h +++ b/gemma/model_store.h @@ -51,19 +51,19 @@ class Gemma; // both, but only write single-file format. class ModelStore { public: - // Reads from file(s) or aborts on error. The latter two arguments are only - // used for pre-2025 files. + // Reads from file(s) or aborts on error. `tokenizer_path` and `wrapping` are + // only used for pre-2025 files. `allow_unknown_config` allows loading models + // with unknown config, if the config is present in the file. ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(), - Tristate wrapping = Tristate::kDefault); + Tristate wrapping = Tristate::kDefault, + Tristate allow_unknown_config = Tristate::kDefault); ~ModelStore(); const ModelConfig& Config() const { - HWY_ASSERT(config_.model != Model::UNKNOWN); return config_; } ModelConfig& MutableConfig() { - HWY_ASSERT(config_.model != Model::UNKNOWN); return config_; }