Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SD VAE switch error after model reuse #12685

Merged
Merged 5 commits on Aug 21, 2023.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 11 additions & 2 deletions modules/sd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,12 @@ def get_sd_model(self):

return self.sd_model

def set_sd_model(self, v):
def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
if already_loaded:
sd_vae.base_vae = getattr(v, "base_vae", None)
sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
sd_vae.checkpoint_info = v.sd_checkpoint_info

try:
self.loaded_sd_models.remove(v)
Expand Down Expand Up @@ -660,13 +664,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
send_model_to_device(already_loaded)
timer.record("send model to device")

model_data.set_sd_model(already_loaded)
model_data.set_sd_model(already_loaded, already_loaded=True)

if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256

print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
Expand All @@ -678,6 +683,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model

sd_vae.base_vae = getattr(sd_model, "base_vae", None)
sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
sd_vae.checkpoint_info = sd_model.sd_checkpoint_info

print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
Expand Down
4 changes: 3 additions & 1 deletion modules/sd_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def load_vae_dict(filename, map_location):


def load_vae(model, vae_file=None, vae_source="from unknown source"):
global vae_dict, loaded_vae_file
global vae_dict, base_vae, loaded_vae_file
# save_settings = False

cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
Expand Down Expand Up @@ -230,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
restore_base_vae(model)

loaded_vae_file = vae_file
model.base_vae = base_vae
model.loaded_vae_file = loaded_vae_file


# don't call this from outside
Expand Down