diff --git a/utils/generate_backend_completeness.py b/utils/generate_backend_completeness.py index 39442259d..06e47270e 100644 --- a/utils/generate_backend_completeness.py +++ b/utils/generate_backend_completeness.py @@ -93,7 +93,7 @@ def parse_module(module_name: str, backend: str, nw_class_name: str) -> list[str def render_table_and_write_to_output( results: list[pl.DataFrame], title: str, output_filename: str ) -> None: - results = ( + results: pl.DataFrame = ( pl.concat(results) .with_columns(supported=pl.lit(":white_check_mark:")) .pivot(on="Backend", values="supported", index=["Method"]) @@ -103,7 +103,10 @@ def render_table_and_write_to_output( .sort("Method") ) - results = results.with_columns(polars=pl.lit(":white_check_mark:")) + backends = [c for c in results.columns if c != "Method"] + ["polars"] + results = results.with_columns(polars=pl.lit(":white_check_mark:")).select( + "Method", *sorted(backends) + ) with pl.Config( tbl_formatting="ASCII_MARKDOWN", @@ -175,7 +178,9 @@ def get_backend_completeness_table() -> None: continue render_table_and_write_to_output( - results=results, title=module_name.capitalize(), output_filename=module_name + results=results, + title=module_name.capitalize().replace("_", "."), + output_filename=module_name, )