-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathSupport_Vector_Machines_Guide.m
96 lines (62 loc) · 3.39 KB
/
Support_Vector_Machines_Guide.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
%%%-------------Support Vector Machines (SVM)
% Loads the Social Network Ads dataset into a table and standardizes the two
% numeric predictors used by the classifier (Age and EstimatedSalary).
%---------------Importing Dataset
% fullfile assembles the path with the platform's file separator, so the
% script also runs on Linux/macOS (a hard-coded '\' only resolves on Windows).
data = readtable(fullfile('Datasets','Social_Network_Ads.csv'));
%---------------Feature Scaling (Standardization Method)
% z = (x - mean(x)) / std(x), computed over every row of the table.
% NOTE(review): these statistics include the rows that are later held out
% for testing - confirm this is acceptable for the intended evaluation.
stand_age = (data.Age - mean(data.Age))/std(data.Age);
data.Age = stand_age;
stand_estimated_salary = (data.EstimatedSalary - mean(data.EstimatedSalary))/std(data.EstimatedSalary);
data.EstimatedSalary = stand_estimated_salary;
%---------------Classifying Data
% Fit an SVM classifier on the table. The formula names Purchased as the
% response and Age and EstimatedSalary as the predictors.
svm_formula = 'Purchased~Age+EstimatedSalary';
classification_model = fitcsvm(data, svm_formula);
%--------------- Customization for classifier
% KernelFunction: linear, gaussian, polynomial
% classification_model = fitcsvm(data,'Purchased~Age+EstimatedSalary','KernelFunction','gaussian');
% OutlierFraction: can be applied to data that contains outliers.
% For example, 0.1 states that around 10% of the instances are expected to be outliers.
% classification_model = fitcsvm(data,'Purchased~Age+EstimatedSalary','OutlierFraction',0.1);
%---------------Partitioning
% Create a hold-out partition that reserves 20% of the observations for
% testing, then cross-validate the classifier with that partition.
holdout_fraction = 0.2;
num_observations = classification_model.NumObservations;
cv = cvpartition(num_observations, 'HoldOut', holdout_fraction);
cross_validated_model = crossval(classification_model,'cvpartition',cv);
%---------------Predictions
% Classify the held-out rows with the first (and, for a hold-out partition,
% only) trained model. Every column except the last one is passed in.
holdout_model = cross_validated_model.Trained{1};
is_test_row = test(cv);
test_predictors = data(is_test_row,1:end-1);
Predictions = predict(holdout_model,test_predictors);
%---------------Analyzing the Results
% Confusion matrix of the known labels of the held-out rows against the
% predicted labels.
true_test_labels = cross_validated_model.Y(is_test_row);
Results = confusionmat(true_test_labels,Predictions);
%---------------Visualizing Training Results
% Plots the class predicted by the hold-out model over a dense grid of
% (Age, EstimatedSalary) points and overlays the training observations,
% coloured by their actual class.
labels = unique(data.Purchased);
classifier_name = 'SVM (Training Results)';
% Grid covering the training rows, padded by 1 on each side, in steps of 0.01
% (both features were standardized earlier in the script).
is_training_row = training(cv);
Age_range = min(data.Age(is_training_row))-1:0.01:max(data.Age(is_training_row))+1;
Estimated_salary_range = min(data.EstimatedSalary(is_training_row))-1:0.01:max(data.EstimatedSalary(is_training_row))+1;
[xx1, xx2] = meshgrid(Age_range,Estimated_salary_range);
XGrid = [xx1(:) xx2(:)];
predictions_meshgrid = predict(cross_validated_model.Trained{1},XGrid);
% Open a dedicated figure. Without it the plot is drawn into whatever figure
% is current - e.g. a figure left over from a previous run with hold still on.
figure
gscatter(xx1(:), xx2(:), predictions_meshgrid,'rgb');
hold on
training_data = data(is_training_row,:);
% Y is true for rows whose label equals the first entry of labels.
Y = ismember(training_data.Purchased,labels{1});
scatter(training_data.Age(Y),training_data.EstimatedSalary(Y), 'o' , 'MarkerEdgeColor', 'black', 'MarkerFaceColor', 'red');
scatter(training_data.Age(~Y),training_data.EstimatedSalary(~Y) , 'o' , 'MarkerEdgeColor', 'black', 'MarkerFaceColor', 'green');
xlabel('Age');
ylabel('Estimated Salary');
title(classifier_name);
legend off, axis tight
% 'Position' is the documented property for a custom
% [left bottom width height] legend placement.
legend(labels,'Position',[0.45,0.01,0.45,0.05],'Orientation','Horizontal');
% Release the hold so later plotting commands do not pile onto these axes.
hold off
%---------------Visualizing Test Results
% Plots the class predicted by the hold-out model over a dense grid of
% (Age, EstimatedSalary) points and overlays the held-out test observations,
% coloured by their actual class.
labels = unique(data.Purchased);
classifier_name = 'SVM (Testing Results)';
% The grid is built from the range of the training rows (padded by 1 on each
% side, step 0.01), matching the grid used for the training plot.
is_training_row = training(cv);
Age_range = min(data.Age(is_training_row))-1:0.01:max(data.Age(is_training_row))+1;
Estimated_salary_range = min(data.EstimatedSalary(is_training_row))-1:0.01:max(data.EstimatedSalary(is_training_row))+1;
[xx1, xx2] = meshgrid(Age_range,Estimated_salary_range);
XGrid = [xx1(:) xx2(:)];
predictions_meshgrid = predict(cross_validated_model.Trained{1},XGrid);
% Separate figure so the test plot does not overwrite the training plot.
figure
gscatter(xx1(:), xx2(:), predictions_meshgrid,'rgb');
hold on
testing_data = data(test(cv),:);
% Y is true for rows whose label equals the first entry of labels.
Y = ismember(testing_data.Purchased,labels{1});
scatter(testing_data.Age(Y),testing_data.EstimatedSalary(Y), 'o' , 'MarkerEdgeColor', 'black', 'MarkerFaceColor', 'red');
scatter(testing_data.Age(~Y),testing_data.EstimatedSalary(~Y) , 'o' , 'MarkerEdgeColor', 'black', 'MarkerFaceColor', 'green');
xlabel('Age');
ylabel('Estimated Salary');
title(classifier_name);
legend off, axis tight
% 'Position' is the documented property for a custom
% [left bottom width height] legend placement.
legend(labels,'Position',[0.45,0.01,0.45,0.05],'Orientation','Horizontal');
% Release the hold so later plotting commands do not pile onto these axes.
hold off