Skip to content

Commit

Permalink
Add torch flop counter to profiler (#693)
Browse files Browse the repository at this point in the history
* Add torch flop counter

* Moving fvcore to a try/catch

* Fix undesired behavior of `transform.PreprocessCfg`

* Remove cls_embed arg from forward/encode_image fns

* Fix `model.get_model_preprocess_cfg`

* Move CLIPA weights to UCSC-VLAA org

* Remove outdated comment in pretrained.py

* Fix arg not being passed to `image_transform` from `image_transform_v2`

* remove hierarchical flop counting

---------

Co-authored-by: Santiago Castro <[email protected]>
Co-authored-by: Ross Wightman <[email protected]>
  • Loading branch information
3 people authored Oct 24, 2023
1 parent 4db1f41 commit 2a46bd9
Showing 1 changed file with 120 additions and 34 deletions.
154 changes: 120 additions & 34 deletions src/training/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import torch
import open_clip
import pandas as pd
from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis

from torch.utils.flop_counter import FlopCounterMode
try:
import fvcore
except:
fvcore = None

parser = argparse.ArgumentParser(description='OpenCLIP Profiler')

Expand All @@ -13,6 +16,8 @@
help='model(s) to profile')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for results')
parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore'])
parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling')


def profile_fvcore(
Expand All @@ -28,12 +33,12 @@ def profile_fvcore(
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
fca = FlopCountAnalysis(model, (example_image_input, example_text_input))
aca = ActivationCountAnalysis(model, (example_image_input, example_text_input))
fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input))
aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input))
if detailed:
fcs = flop_count_str(fca)
fcs = fvcore.nn.flop_count_str(fca)
print(fcs)
return fca.total(), aca.total()
return fca.total() / batch_size, aca.total() / batch_size


def profile_fvcore_text(
Expand All @@ -47,12 +52,12 @@ def profile_fvcore_text(
model = model.to('cpu')
device = next(model.parameters()).device
example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
fca = FlopCountAnalysis(model, example_input)
aca = ActivationCountAnalysis(model, example_input)
fca = fvcore.nn.FlopCountAnalysis(model, example_input)
aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
if detailed:
fcs = flop_count_str(fca)
fcs = fvcore.nn.flop_count_str(fca)
print(fcs)
return fca.total(), aca.total()
return fca.total() / batch_size, aca.total() / batch_size


def profile_fvcore_image(
Expand All @@ -66,19 +71,64 @@ def profile_fvcore_image(
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
fca = FlopCountAnalysis(model, example_input)
aca = ActivationCountAnalysis(model, example_input)
fca = fvcore.nn.FlopCountAnalysis(model, example_input)
aca = fvcore.nn.ActivationCountAnalysis(model, example_input)
if detailed:
fcs = flop_count_str(fca)
fcs = fvcore.nn.flop_count_str(fca)
print(fcs)
return fca.total(), aca.total()
return fca.total() / batch_size, aca.total() / batch_size


def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False):
"""Profile the image encoder using torch.utils.flop_counter"""
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)

flop_counter = FlopCounterMode()
with flop_counter:
model(example_input)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
return total_flops / batch_size


def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False):
"""Profile the text encoder using torch.utils.flop_counter"""
if force_cpu:
model = model.to('cpu')
device = next(model.parameters()).device
example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)

flop_counter = FlopCounterMode()
with flop_counter:
model(example_input)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
return total_flops / batch_size


def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False):
"""Profile the full model using torch.utils.flop_counter"""
if force_cpu:
model = model.to('cpu')
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)

flop_counter = FlopCounterMode()
with flop_counter:
model(image_input, text_input)
total_flops = sum(flop_counter.get_flop_counts()['Global'].values())
return total_flops / batch_size


def count_params(model):
return sum([m.numel() for m in model.parameters()])


def profile_model(model_name):
def profile_model(model_name, batch_size=1, profiler='torch'):
assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
if profiler == 'fvcore':
assert fvcore is not None, 'Please install fvcore.'
model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
model.eval()
if torch.cuda.is_available():
Expand All @@ -88,7 +138,10 @@ def profile_model(model_name):
image_input_size = (3,) + tuple(model.visual.image_size[-2:])
else:
image_input_size = (3, model.visual.image_size, model.visual.image_size)

text_input_size = (77,)
if hasattr(model, 'context_length') and model.context_length:
text_input_size = (model.context_length,)

results = {}
results['model'] = model_name
Expand All @@ -110,24 +163,40 @@ def profile_model(model_name):
while retries:
retries -= 1
try:
macs, acts = profile_fvcore(
model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries)

image_macs, image_acts = profile_fvcore_image(
model.visual, image_input_size=image_input_size, force_cpu=not retries)

text_macs, text_acts = profile_fvcore_text(
model.text, text_input_size=text_input_size, force_cpu=not retries)

results['gmacs'] = round(macs / 1e9, 2)
results['macts'] = round(acts / 1e6, 2)
results['mparams'] = round(count_params(model) / 1e6, 2)
results['image_gmacs'] = round(image_macs / 1e9, 2)
results['image_macts'] = round(image_acts / 1e6, 2)
results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
results['text_gmacs'] = round(text_macs / 1e9, 2)
results['text_macts'] = round(text_acts / 1e6, 2)
results['text_mparams'] = round(count_params(model.text) / 1e6, 2)

if profiler == 'fvcore':
macs, acts = profile_fvcore(
model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)

image_macs, image_acts = profile_fvcore_image(
model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)

text_macs, text_acts = profile_fvcore_text(
model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)

results['gmacs'] = round(macs / 1e9, 2)
results['macts'] = round(acts / 1e6, 2)

results['image_gmacs'] = round(image_macs / 1e9, 2)
results['image_macts'] = round(image_acts / 1e6, 2)

results['text_gmacs'] = round(text_macs / 1e9, 2)
results['text_macts'] = round(text_acts / 1e6, 2)
elif profiler == 'torch':
image_flops = profile_torch_image(
model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size)
text_flops = profile_torch_text(
model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)
total_flops = profile_torch(
model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size)

results['gflops'] = round(total_flops / 1e9, 2)
results['image_gflops'] = round(image_flops / 1e9, 2)
results['text_gflops'] = round(text_flops / 1e9, 2)

except RuntimeError as e:
pass
return results
Expand All @@ -143,16 +212,33 @@ def main():
parsed_model = args.model.split(',')

results = []
models_with_errors = []
for m in parsed_model:
row = profile_model(m)
results.append(row)
print('='*100)
print(f'Profiling {m}')
try:
row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
results.append(row)
except Exception as e:
print(f'Error profiling {m}: {e}')
import traceback
traceback.print_exc()
models_with_errors.append(m)

df = pd.DataFrame(results, columns=results[0].keys())
df = df.sort_values('gmacs')
if 'gmacs' in df.columns:
df = df.sort_values('gmacs')
else:
df = df.sort_values('gflops')

print('='*100)
print('Done.')
print(df)
if args.results_file:
df.to_csv(args.results_file, index=False)

print('Models with errors:', models_with_errors)


if __name__ == '__main__':
main()

0 comments on commit 2a46bd9

Please sign in to comment.