Skip to content

Commit

Permalink
[Enhance] Add ‘config_name' as a supplement to the 'model_setting' (#…
Browse files Browse the repository at this point in the history
…2027)

* add config name

* fix lint

* update ut
  • Loading branch information
liuwenran authored Sep 11, 2023
1 parent bde84a0 commit 6fda2cc
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 21 deletions.
5 changes: 5 additions & 0 deletions demo/mmagic_inference_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def parse_args():
type=int,
default=None,
help='Pretrained mmagic algorithm setting')
parser.add_argument(
'--config-name',
type=str,
default=None,
help='Pretrained mmagic algorithm config name')
parser.add_argument(
'--model-config',
type=str,
Expand Down
6 changes: 5 additions & 1 deletion demo/mmagic_inference_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,13 @@
"\n",
"There are some different configs and checkpoints for one model.\n",
"\n",
"You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0.\n",
"\n",
"Take conditional GAN model 'biggan' as an example. We have pretrained model for Cifar and Imagenet, and all pretrained models of 'biggan' are listed in its [metafile.yaml](../configs/biggan/metafile.yml)\n",
"\n",
"You could configure different settings by passing 'model_setting' to 'MMagicInferencer'. Every model's default setting is 0."
"There are six settings in this metafile. If you choose setting 1, then the config 'configs/biggan/biggan_ajbrock-sn_8xb32-1500kiters_imagenet1k-128x128.py' will be used. If 'model_setting' is not passed to 'MMagicInferencer', the config ‘configs/biggan/biggan_2xb25-500kiters_cifar10-32x32.py’ will be used by default.\n",
"\n",
"And you could also use 'config_name' to replace 'model_setting'. For example, you can init a MMagicInferencer with 'MMagicInferencer('biggan', config_name='biggan_2xb25-500kiters_cifar10-32x32')', which is the same with 'MMagicInferencer('biggan', model_setting=0)'."
]
},
{
Expand Down
5 changes: 3 additions & 2 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
postprocess=[])

def preprocess(self,
text: InputsType,
text: InputsType = None,
negative_prompt: InputsType = None,
num_inference_steps: int = 20,
height=None,
Expand All @@ -37,7 +37,8 @@ def preprocess(self,
result(Dict): Results of preprocess.
"""
result = self.extra_parameters
result['prompt'] = text
if text:
result['prompt'] = text
if negative_prompt:
result['negative_prompt'] = negative_prompt
if num_inference_steps:
Expand Down
9 changes: 8 additions & 1 deletion mmagic/apis/mmagic_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ class MMagicInferencer:
def __init__(self,
model_name: str = None,
model_setting: int = None,
config_name: int = None,
model_config: str = None,
model_ckpt: str = None,
device: torch.device = None,
Expand All @@ -140,14 +141,15 @@ def __init__(self,
MMagicInferencer.init_inference_supported_models_cfg()
inferencer_kwargs = {}
inferencer_kwargs.update(
self._get_inferencer_kwargs(model_name, model_setting,
self._get_inferencer_kwargs(model_name, model_setting, config_name,
model_config, model_ckpt,
extra_parameters))
self.inferencer = Inferencers(
device=device, seed=seed, **inferencer_kwargs)

def _get_inferencer_kwargs(self, model_name: Optional[str],
model_setting: Optional[int],
config_name: Optional[int],
model_config: Optional[str],
model_ckpt: Optional[str],
extra_parameters: Optional[Dict]) -> Dict:
Expand All @@ -161,6 +163,11 @@ def _get_inferencer_kwargs(self, model_name: Optional[str],
if model_setting:
setting_to_use = model_setting
config_dir = cfgs['settings'][setting_to_use]['Config']
if config_name:
for setting in cfgs['settings']:
if setting['Name'] == config_name:
config_dir = setting['Config']
break
config_dir = config_dir[config_dir.find('configs'):]
if osp.exists(
osp.join(osp.dirname(__file__), '..', '..', config_dir)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import platform

import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

Expand All @@ -21,24 +20,11 @@
def test_diffusers_pipeline_inferencer():
cfg = dict(
model=dict(
type='DiffusionPipeline',
from_pretrained='runwayml/stable-diffusion-v1-5'))
type='DiffusionPipeline', from_pretrained='google/ddpm-cat-256'))

inferencer_instance = DiffusersPipelineInferencer(cfg, None)

def mock_infer(*args, **kwargs):
return dict(samples=torch.randn(1, 3, 64, 64))

inferencer_instance.model.infer = mock_infer

text_prompts = 'Japanese anime style, girl'
negative_prompt = 'bad face, bad hands'
result = inferencer_instance(
text=text_prompts,
negative_prompt=negative_prompt,
height=64,
width=64)
assert result[1][0].size == (64, 64)
result = inferencer_instance()
assert result[1][0].size == (256, 256)


def teardown_module():
Expand Down

0 comments on commit 6fda2cc

Please sign in to comment.