Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion compression/python/compression_clif_aux.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ SbsWriter::SbsWriter(const std::string& path)
: impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)(path)) {}

SbsReader::SbsReader(const std::string& path)
: reader_(Path(path)), model_(reader_) {}
: reader_(Path(path)),
model_(reader_, Path(), /*wrapping=*/Tristate::kDefault,
/*allow_unknown_config=*/Tristate::kTrue) {}

} // namespace gcpp
#endif // HWY_ONCE
34 changes: 19 additions & 15 deletions gemma/model_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <charconv>
#include <cstdlib>
#include <cstring> // strcmp
#include <optional>
#include <string>
#include <system_error> // std::errc // NOLINT

Expand Down Expand Up @@ -250,31 +251,34 @@ static int DeduceLayerTypes(const BlobReader& reader) {
// `wrapping_override` is forwarded from the command line. For pre-2025 files
// without `ModelConfig`, it is the only way to force PT.
static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
Tristate wrapping_override) {
Tristate wrapping_override,
Tristate allow_unknown_config) {
const TypePrefix type_prefix(reader.Keys(), reader);
Type deduced_weight = Type::kUnknown;
if (type_prefix.HasPrefixes()) {
deduced_weight = type_prefix.DeduceWeightType();
type_prefix.PrintTypeBytes();
}

// Always deduce so we can verify it against the config we read.
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
const Model deduced_model =
DeduceModel(reader.blob_path(), layers, layer_types);
// Optionally deduce so we can verify it against the config we read.
std::optional<Model> deduced_model;
const bool has_config = reader.Find(kConfigName);
if (!has_config || allow_unknown_config != Tristate::kTrue) {
const size_t layers = DeduceNumLayers(reader.Keys());
const int layer_types = DeduceLayerTypes(reader);
deduced_model = DeduceModel(reader.blob_path(), layers, layer_types);
}

ModelConfig config;
// Check first to prevent `CallWithSpan` from printing a warning.
if (reader.Find(kConfigName)) {
if (has_config) {
HWY_ASSERT(reader.CallWithSpan<uint32_t>(
kConfigName, [&config](const SerializedSpan serialized) {
const IFields::ReadResult result = config.Read(serialized, 0);
WarnIfExtra(result, kConfigName);
HWY_ASSERT_M(result.pos != 0, "Error deserializing config");
}));

HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
HWY_ASSERT(config.weight != Type::kUnknown);
for (const LayerConfig& layer_config : config.layer_configs) {
Expand All @@ -285,18 +289,19 @@ static ModelConfig ReadOrDeduceConfig(BlobReader& reader,

// We trust the deserialized config, but checking helps to validate the
// deduction, which we rely on below for pre-2025 files.
if (config.model != deduced_model) {
if (deduced_model.has_value() && config.model != *deduced_model) {
const std::string suffix = WrappingSuffix(config.wrapping);
HWY_WARN("Detected model %s does not match config %s.",
(std::string(ModelPrefix(deduced_model)) + suffix).c_str(),
(std::string(ModelPrefix(*deduced_model)) + suffix).c_str(),
(std::string(ModelPrefix(config.model)) + suffix).c_str());
}
return config;
}

// Pre-2025 format: no config, rely on deduction plus `wrapping_override`.
return ModelConfig(deduced_model, deduced_weight,
ChooseWrapping(deduced_model, wrapping_override));
HWY_ASSERT(deduced_model.has_value());
return ModelConfig(*deduced_model, deduced_weight,
ChooseWrapping(*deduced_model, wrapping_override));
}

static std::vector<float> ReadScales(BlobReader& reader,
Expand Down Expand Up @@ -387,8 +392,8 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
}

ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
Tristate wrapping)
: config_(ReadOrDeduceConfig(reader, wrapping)),
Tristate wrapping, Tristate allow_unknown_config)
: config_(ReadOrDeduceConfig(reader, wrapping, allow_unknown_config)),
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
if (!ReadMatPtrs(reader)) { // Pre-2025 format.
CreateMatPtrs(reader);
Expand Down Expand Up @@ -460,7 +465,6 @@ static void AddBlob(const char* name, const std::vector<uint32_t>& data,
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter& writer) {
HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.weight != Type::kUnknown);
HWY_ASSERT(config.wrapping != PromptWrapping::kSentinel);
const std::vector<uint32_t> serialized_config = config.Write();
Expand Down
10 changes: 5 additions & 5 deletions gemma/model_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ class Gemma;
// both, but only write single-file format.
class ModelStore {
public:
// Reads from file(s) or aborts on error. The latter two arguments are only
// used for pre-2025 files.
// Reads from file(s) or aborts on error. `tokenizer_path` and `wrapping` are
// only used for pre-2025 files. `allow_unknown_config` allows loading models
// with unknown config, if the config is present in the file.
ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
Tristate wrapping = Tristate::kDefault);
Tristate wrapping = Tristate::kDefault,
Tristate allow_unknown_config = Tristate::kDefault);
~ModelStore();

const ModelConfig& Config() const {
HWY_ASSERT(config_.model != Model::UNKNOWN);
return config_;
}

ModelConfig& MutableConfig() {
HWY_ASSERT(config_.model != Model::UNKNOWN);
return config_;
}

Expand Down
Loading