
How to load from the CogVideo official pretrained model? #18

Open
jianlong-yuan opened this issue Sep 23, 2024 · 4 comments

@jianlong-yuan

No description provided.

@bubbliiiing
Collaborator

Only the t2v model is supported.
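
For illustration only, loading the official CogVideoX text-to-video weights from Hugging Face with diffusers could look roughly like the sketch below. The model ID THUDM/CogVideoX-2b and the CogVideoXTransformer3DModel class come from the public diffusers release; how the loaded weights are handed to this repo's training script is an assumption, not something confirmed in this thread.

# Minimal sketch (not this repo's code): load the official CogVideoX t2v
# transformer from Hugging Face via diffusers (requires a recent diffusers release).
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-2b",       # official text-to-video checkpoint
    subfolder="transformer",    # transformer weights live in this subfolder
    torch_dtype=torch.float16,
)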

@jianlong-yuan
Author

How do I resume the dataset (data iterator) state?

@bubbliiiing
Collaborator

Set --resume_from_checkpoint="latest"

@jianlong-yuan
Copy link
Author

Set --resume_from_checkpoint="latest"

Looking at the code, only the model parameters and the epoch are loaded; the data iterator position is not resumed.

# Quoted from the training script's resume logic:
if args.resume_from_checkpoint:
    if args.resume_from_checkpoint != "latest":
        path = os.path.basename(args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        dirs = os.listdir(args.output_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if len(dirs) > 0 else None

    if path is None:
        accelerator.print(
            f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
        )
        args.resume_from_checkpoint = None
        initial_global_step = 0
    else:
        global_step = int(path.split("-")[1])

        initial_global_step = global_step

        # first_epoch is recovered from sampler_pos_start.pkl if present, otherwise
        # derived from global_step. Note that the first element of the pickle
        # (presumably the saved sampler position) is discarded here.
        pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl")
        if os.path.exists(pkl_path):
            with open(pkl_path, 'rb') as file:
                _, first_epoch = pickle.load(file)
        else:
            first_epoch = global_step // num_update_steps_per_epoch
        print(f"Load pkl from {pkl_path}. Get first_epoch = {first_epoch}.")

        # load_state restores the model/optimizer/scheduler, but nothing here
        # advances the dataloader to the position it had when the checkpoint was saved.
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))