-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcreate_tuning_data.dart
190 lines (155 loc) · 5.02 KB
/
create_tuning_data.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Copyright (c) 2024, the Dart project authors. Please see the AUTHORS file
// for details. All rights reserved. Use of this source code is governed by a
// BSD-style license that can be found in the LICENSE file.
import 'dart:convert';
import 'dart:io';
import 'package:github/github.dart';
import 'package:sdk_triage_bot/src/common.dart';
import 'package:sdk_triage_bot/src/github.dart';
import 'package:sdk_triage_bot/src/prompts.dart';
// Downloads 500-1000 already-triaged GitHub issues and writes files suitable
// for tuning a Gemini model (via https://aistudio.google.com/).
//
// When sampling, we:
// - take more issues from the more common areas
// - take at least 10 issues from each area
/// The maximum number of issues to sample for each area label.
///
/// Keys are GitHub label names; values are the per-label sample limits.
/// Entries are grouped into tiers by sample size, largest first.
const areaSampleCount = <String, int>{
  // 100 samples each.
  'area-analyzer': 100,
  'area-core-library': 100,
  'area-front-end': 100,
  'area-vm': 100,
  'area-web': 100,

  // 50 samples each.
  'area-dart-cli': 50,
  'area-infrastructure': 50,
  'area-language': 50,
  'area-test': 50,

  // 25 samples each.
  'area-dart2wasm': 25,
  'area-meta': 25,
  'area-pkg': 25,

  // 10 samples each.
  'area-build': 10,
  'area-google3': 10,
  'area-intellij': 10,
  'area-native-interop': 10,
  'area-sdk': 10,
  'area-tools': 10,
};
/// Builds the tuning data files.
///
/// For every entry in [areaSampleCount], downloads up to that many issues via
/// [downloadIssues], merges the results keyed by issue number, sorts them by
/// descending issue number, and writes three files: `tool/training.csv`,
/// `tool/training.jsonl` and `tool/training.txt`.
///
/// [args] is currently unused.
Future<void> main(List<String> args) async {
  print('Building tuning data...');
  print('');

  // Download issues. Keying by issue number means an issue returned for more
  // than one label is only kept once.
  final issueMap = <int, Issue>{};
  for (final entry in areaSampleCount.entries) {
    final areaLabel = entry.key;
    final count = entry.value;

    final results = await downloadIssues(areaLabel, count);
    print('Downloaded ${results.length} issues from $areaLabel');

    for (final issue in results) {
      issueMap[issue.number] = issue;
    }
  }

  // Sort by issue number, highest first.
  final issues = issueMap.values.toList()
    ..sort((a, b) => b.number.compareTo(a.number));

  // Emit the training files; each one gets one line per issue plus a trailing
  // newline.
  final trainingFileCsv = File('tool/training.csv');
  final trainingFileJsonl = File('tool/training.jsonl');
  final trainingFileDesc = File('tool/training.txt');

  final trainingDataCsv =
      issues.map((issue) => issue.trainingRowCSV).join('\n');
  trainingFileCsv.writeAsStringSync('$trainingDataCsv\n');

  final trainingDataJsonl =
      issues.map((issue) => issue.trainingRowJsonl).join('\n');
  trainingFileJsonl.writeAsStringSync('$trainingDataJsonl\n');

  final trainingDesc = issues.map((issue) => issue.trainingDesc).join('\n');
  trainingFileDesc.writeAsStringSync('$trainingDesc\n');

  // Report every file that was written, including the description file.
  print('');
  print('Wrote training data to ${trainingFileCsv.path}, '
      '${trainingFileJsonl.path} and ${trainingFileDesc.path}.');

  // NOTE(review): the explicit exit is presumably here because something
  // (e.g. an open HTTP client) would otherwise keep the process alive -
  // confirm before removing.
  exit(0);
}
/// Collects up to [count] issues carrying [areaLabel].
///
/// Issues are requested page by page through `fetchIssues`, passing the
/// previous page's cursor to obtain the next one. Collection stops as soon as
/// [count] issues have been gathered, a page comes back empty, or the last
/// page reports that there is no next page.
///
/// [includeClosed] is forwarded unchanged to every `fetchIssues` call.
///
/// The limit is checked after each issue is added, so a non-positive [count]
/// still yields one issue when any are available.
Future<List<Issue>> downloadIssues(
  String areaLabel,
  int count, {
  bool includeClosed = false,
}) async {
  final collected = <Issue>[];
  var page = await fetchIssues(areaLabel, includeClosed: includeClosed);

  while (true) {
    // An empty page ends the download.
    if (page.issues.isEmpty) return collected;

    for (final issue in page.issues) {
      collected.add(issue);
      if (collected.length >= count) return collected;
    }

    // No further pages to request.
    if (!page.hasNext) return collected;

    page = await fetchIssues(
      areaLabel,
      includeClosed: includeClosed,
      cursor: page.cursor,
    );
  }
}
/// Renderings of an [Issue] used when emitting tuning data.
extension on Issue {
  /// The names of this issue's labels that start with `area-` or `type-`, in
  /// the order they appear in [labels].
  List<String> get _areaAndTypeLabels => [
        for (final label in labels)
          if (label.name.startsWith('area-') || label.name.startsWith('type-'))
            label.name,
      ];

  /// The model input for this issue: the result of `assignAreaPrompt` for the
  /// title and the trimmed body.
  ///
  /// A missing body is treated as empty, so a single issue without body text
  /// does not abort the whole run.
  String get _promptInput =>
      assignAreaPrompt(title: title, body: trimmedBody(bodyText ?? ''));

  /// One CSV row: the prompt input, then the comma-separated area/type labels,
  /// each encoded with [csvEncode].
  String get trainingRowCSV {
    final output = _areaAndTypeLabels.join(', ');
    return '${csvEncode(_promptInput)},${csvEncode(output)}';
  }

  /// One JSON Lines row: a `messages` list holding the prompt input as the
  /// `user` message and the comma-separated area/type labels as the `model`
  /// message.
  String get trainingRowJsonl {
    final output = _areaAndTypeLabels.join(', ');

    return jsonEncode({
      'messages': [
        {'role': 'user', 'content': _promptInput},
        {'role': 'model', 'content': output},
      ],
    });
  }

  /// A one-line summary: the issue number, the title (cut to 80 characters and
  /// suffixed with `...` when longer), and the area/type labels.
  String get trainingDesc {
    var shortTitle = title;
    if (shortTitle.length > 80) {
      shortTitle = '${shortTitle.substring(0, 80)}...';
    }

    return '[$number] "$shortTitle": ${_areaAndTypeLabels.join(', ')}';
  }

  /// A markdown table row with the issue number and its sorted, back-quoted
  /// `area-`, `type-` and `needs-info` labels.
  // ignore: unused_element
  String get markdownDesc {
    final filteredLabels = labels.map((l) => l.name).where((label) {
      return label.startsWith('area-') ||
          label.startsWith('type-') ||
          label == 'needs-info';
    }).toList()
      ..sort();
    final descriptions = filteredLabels.map((l) => '`$l`').toList();
    return '| #$number | ${descriptions.join(', ')} |';
  }
}
/// Encodes [str] as a single CSV field.
///
/// Each line feed is replaced with the literal text ` \n ` so that a record
/// stays on one physical line, and embedded double quotes are doubled. The
/// result is wrapped in double quotes when it contains a comma, a carriage
/// return, a space, a single quote or a double quote; otherwise it is returned
/// bare.
String csvEncode(String str) {
  final escaped = str.replaceAll('\n', r' \n ').replaceAll('"', '""');

  // A comma would otherwise be read as a field separator and a carriage return
  // as (part of) a record separator, so both force quoting.
  final needsQuoting = escaped.contains(',') ||
      escaped.contains('\r') ||
      escaped.contains("'") ||
      escaped.contains(' ') ||
      escaped.contains('"');

  return needsQuoting ? '"$escaped"' : escaped;
}