diff --git a/src/training/profiler.py b/src/training/profiler.py index f10372cde..6c90a5270 100644 --- a/src/training/profiler.py +++ b/src/training/profiler.py @@ -3,8 +3,11 @@ import torch import open_clip import pandas as pd -from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis - +from torch.utils.flop_counter import FlopCounterMode +try: + import fvcore +except: + fvcore = None parser = argparse.ArgumentParser(description='OpenCLIP Profiler') @@ -13,6 +16,8 @@ help='model(s) to profile') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for results') +parser.add_argument('--profiler', default='torch', type=str, choices=['torch', 'fvcore']) +parser.add_argument('--batch-size', default=1, type=int, help='Batch size for profiling') def profile_fvcore( @@ -28,12 +33,12 @@ def profile_fvcore( device, dtype = next(model.parameters()).device, next(model.parameters()).dtype example_image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) example_text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) - fca = FlopCountAnalysis(model, (example_image_input, example_text_input)) - aca = ActivationCountAnalysis(model, (example_image_input, example_text_input)) + fca = fvcore.nn.FlopCountAnalysis(model, (example_image_input, example_text_input)) + aca = fvcore.nn.ActivationCountAnalysis(model, (example_image_input, example_text_input)) if detailed: - fcs = flop_count_str(fca) + fcs = fvcore.nn.flop_count_str(fca) print(fcs) - return fca.total(), aca.total() + return fca.total() / batch_size, aca.total() / batch_size def profile_fvcore_text( @@ -47,12 +52,12 @@ def profile_fvcore_text( model = model.to('cpu') device = next(model.parameters()).device example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) - fca = FlopCountAnalysis(model, example_input) - aca = ActivationCountAnalysis(model, example_input) + fca = fvcore.nn.FlopCountAnalysis(model, example_input) + aca = fvcore.nn.ActivationCountAnalysis(model, example_input) if detailed: - fcs = flop_count_str(fca) + fcs = fvcore.nn.flop_count_str(fca) print(fcs) - return fca.total(), aca.total() + return fca.total() / batch_size, aca.total() / batch_size def profile_fvcore_image( @@ -66,19 +71,64 @@ def profile_fvcore_image( model = model.to('cpu') device, dtype = next(model.parameters()).device, next(model.parameters()).dtype example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) - fca = FlopCountAnalysis(model, example_input) - aca = ActivationCountAnalysis(model, example_input) + fca = fvcore.nn.FlopCountAnalysis(model, example_input) + aca = fvcore.nn.ActivationCountAnalysis(model, example_input) if detailed: - fcs = flop_count_str(fca) + fcs = fvcore.nn.flop_count_str(fca) print(fcs) - return fca.total(), aca.total() + return fca.total() / batch_size, aca.total() / batch_size + + +def profile_torch_image(model, image_input_size, batch_size=1, force_cpu=False): + """Profile the image encoder using torch.utils.flop_counter""" + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + example_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + + flop_counter = FlopCounterMode() + with flop_counter: + model(example_input) + total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) + return total_flops / batch_size + + +def profile_torch_text(model, text_input_size, batch_size=1, force_cpu=False): + """Profile the text encoder using torch.utils.flop_counter""" + if force_cpu: + model = model.to('cpu') + device = next(model.parameters()).device + example_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + + flop_counter = FlopCounterMode() + with flop_counter: + model(example_input) + total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) + return total_flops / batch_size + + +def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_cpu=False): + """Profile the full model using torch.utils.flop_counter""" + if force_cpu: + model = model.to('cpu') + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + image_input = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype) + text_input = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64) + + flop_counter = FlopCounterMode() + with flop_counter: + model(image_input, text_input) + total_flops = sum(flop_counter.get_flop_counts()['Global'].values()) + return total_flops / batch_size def count_params(model): return sum([m.numel() for m in model.parameters()]) - -def profile_model(model_name): +def profile_model(model_name, batch_size=1, profiler='torch'): + assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported' + if profiler == 'fvcore': + assert fvcore is not None, 'Please install fvcore.' model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False) model.eval() if torch.cuda.is_available(): @@ -88,7 +138,10 @@ def profile_model(model_name): image_input_size = (3,) + tuple(model.visual.image_size[-2:]) else: image_input_size = (3, model.visual.image_size, model.visual.image_size) + text_input_size = (77,) + if hasattr(model, 'context_length') and model.context_length: + text_input_size = (model.context_length,) results = {} results['model'] = model_name @@ -110,24 +163,40 @@ def profile_model(model_name): while retries: retries -= 1 try: - macs, acts = profile_fvcore( - model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries) - - image_macs, image_acts = profile_fvcore_image( - model.visual, image_input_size=image_input_size, force_cpu=not retries) - - text_macs, text_acts = profile_fvcore_text( - model.text, text_input_size=text_input_size, force_cpu=not retries) - - results['gmacs'] = round(macs / 1e9, 2) - results['macts'] = round(acts / 1e6, 2) results['mparams'] = round(count_params(model) / 1e6, 2) - results['image_gmacs'] = round(image_macs / 1e9, 2) - results['image_macts'] = round(image_acts / 1e6, 2) results['image_mparams'] = round(count_params(model.visual) / 1e6, 2) - results['text_gmacs'] = round(text_macs / 1e9, 2) - results['text_macts'] = round(text_acts / 1e6, 2) results['text_mparams'] = round(count_params(model.text) / 1e6, 2) + + if profiler == 'fvcore': + macs, acts = profile_fvcore( + model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) + + image_macs, image_acts = profile_fvcore_image( + model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) + + text_macs, text_acts = profile_fvcore_text( + model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) + + results['gmacs'] = round(macs / 1e9, 2) + results['macts'] = round(acts / 1e6, 2) + + results['image_gmacs'] = round(image_macs / 1e9, 2) + results['image_macts'] = round(image_acts / 1e6, 2) + + results['text_gmacs'] = round(text_macs / 1e9, 2) + results['text_macts'] = round(text_acts / 1e6, 2) + elif profiler == 'torch': + image_flops = profile_torch_image( + model.visual, image_input_size=image_input_size, force_cpu=not retries, batch_size=batch_size) + text_flops = profile_torch_text( + model.text, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) + total_flops = profile_torch( + model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries, batch_size=batch_size) + + results['gflops'] = round(total_flops / 1e9, 2) + results['image_gflops'] = round(image_flops / 1e9, 2) + results['text_gflops'] = round(text_flops / 1e9, 2) + except RuntimeError as e: pass return results @@ -143,16 +212,33 @@ def main(): parsed_model = args.model.split(',') results = [] + models_with_errors = [] for m in parsed_model: - row = profile_model(m) - results.append(row) + print('='*100) + print(f'Profiling {m}') + try: + row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler) + results.append(row) + except Exception as e: + print(f'Error profiling {m}: {e}') + import traceback + traceback.print_exc() + models_with_errors.append(m) df = pd.DataFrame(results, columns=results[0].keys()) - df = df.sort_values('gmacs') + if 'gmacs' in df.columns: + df = df.sort_values('gmacs') + else: + df = df.sort_values('gflops') + + print('='*100) + print('Done.') print(df) if args.results_file: df.to_csv(args.results_file, index=False) + print('Models with errors:', models_with_errors) + if __name__ == '__main__': main()