-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
642 train 3d model with lucchi data #650
base: dev
Are you sure you want to change the base?
Conversation
* Update faq.md * Add described answer to the issue
Updates to SAM 3d training --------- Co-authored-by: Constantin Pape <[email protected]>
…er (#645) Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapper, simplify lora support
Minor fix to trainable sam model functionality
Clean up interfaces related to 3d models and PEFT
assert os.path.exists(image_path), image_path | ||
|
||
# Perform segmentation only on the semantic class | ||
for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest to remove this part already, it doesn't make sense in the context here.
def predict(args): | ||
|
||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
if args.checkpoint_path is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure why you would ever run prediction without a checkpoint. I would not make this optional.
lora_rank=4, | ||
model_type=args.model_type, | ||
checkpoint_path=cp_path | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will not work to actually load the checkpoint. Please read the code I send you carefully and see how I use torch_em.util.load_model
to get the checkpoint. Not that this quite complex way is now necessary because of the way how we mix the 3d adapter and LoRA. I will think about a better way at some point, but for now we need to use this work-around.
return raw | ||
|
||
|
||
class LucchiSegmentationDataset(SegmentationDataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can now be removed.
… default_sam_dataset
No description provided.