-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathsaml.go
131 lines (113 loc) · 3.42 KB
/
saml.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package beyond
import (
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"flag"
"net/http"
"net/url"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/pkg/errors"
dsig "github.com/russellhaering/goxmldsig"
)
// Command-line flags configuring beyond's SAML service provider (SP).
// SAML is enabled only when -saml-metadata-url is non-empty.
var (
	// SP certificate and private key, loaded as a PEM key pair by samlSetup.
	samlCert = flag.String("saml-cert-file", "example/myservice.cert", "SAML SP path to cert.pem")
	samlKey  = flag.String("saml-key-file", "example/myservice.key", "SAML SP path to key.pem")

	// Identity of this SP and the location of the IdP's metadata document.
	samlID  = flag.String("saml-entity-id", "", "SAML SP entity ID (blank defaults to beyond-host)")
	samlIDP = flag.String("saml-metadata-url", "", "SAML metadata URL from IdP (blank disables SAML)")

	// Protocol options applied to the SP after it is constructed, plus the
	// attribute samlFilter reads to identify the user.
	samlNIDF         = flag.String("saml-nameid-format", "email", "SAML SP option: {email, persistent, transient, unspecified}")
	samlAttr         = flag.String("saml-session-key", "email", "SAML attribute to map from session")
	samlSignRequests = flag.Bool("saml-sign-requests", false, "SAML SP signs authentication requests")
	samlSignMethod   = flag.String("saml-signature-method", "", "SAML SP option: {sha1, sha256, sha512}")
)

// samlSP is the SAML middleware built by samlSetup. samlSetup returns
// without assigning it when -saml-metadata-url is empty.
var samlSP *samlsp.Middleware
// samlSetup builds the package-level SAML middleware (samlSP) from the
// saml-* command-line flags.
//
// When -saml-metadata-url is empty, SAML is disabled. samlSetup then returns
// nil without touching samlSP. Otherwise it:
//   - loads the SP key pair (RSA required),
//   - fetches the IdP metadata,
//   - constructs the middleware for https://<host>,
//   - applies the optional NameID format and signature method.
//
// samlSP is assigned only after every step succeeds, so a failed setup never
// publishes a half-configured middleware. An empty -saml-entity-id is
// overwritten in place with *host.
//
// Errors from any step are returned wrapped with a description of that step.
func samlSetup() error {
	if *samlIDP == "" {
		return nil
	}
	if *samlID == "" {
		*samlID = *host
	}

	keyPair, err := tls.LoadX509KeyPair(*samlCert, *samlKey)
	if err != nil {
		return errors.Wrap(err, "loading SAML key pair")
	}
	// Older Go releases do not populate Leaf in LoadX509KeyPair, and the
	// middleware's Certificate option below is taken from it.
	keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
	if err != nil {
		return errors.Wrap(err, "parsing SAML certificate")
	}
	// The middleware is given an *rsa.PrivateKey. Reject other key types
	// (e.g. ECDSA) with an error rather than panicking on the assertion.
	key, ok := keyPair.PrivateKey.(*rsa.PrivateKey)
	if !ok {
		return errors.Errorf("%q: unsupported private key type %T, RSA key required", *samlKey, keyPair.PrivateKey)
	}

	idpMetadataURL, err := url.Parse(*samlIDP)
	if err != nil {
		return errors.Wrap(err, "parsing saml-metadata-url")
	}
	// NOTE(review): http.DefaultClient has no timeout, so an unresponsive IdP
	// blocks this call indefinitely.
	idpMetadata, err := samlsp.FetchMetadata(
		context.Background(), http.DefaultClient,
		*idpMetadataURL)
	if err != nil {
		return errors.Wrap(err, "fetching SAML IdP metadata")
	}

	// Previously this error was discarded and a stale err was checked instead.
	rootURL, err := url.Parse("https://" + *host)
	if err != nil {
		return errors.Wrap(err, "building SAML SP root URL")
	}

	sp, err := samlsp.New(samlsp.Options{
		EntityID:          *samlID,
		SignRequest:       *samlSignRequests,
		URL:               *rootURL,
		Certificate:       keyPair.Leaf,
		IDPMetadata:       idpMetadata,
		Key:               key,
		AllowIDPInitiated: true,
	})
	if err != nil {
		return errors.Wrap(err, "creating SAML middleware")
	}

	// An empty value keeps whatever NameID format samlsp.New chose.
	switch *samlNIDF {
	case "email":
		sp.ServiceProvider.AuthnNameIDFormat = saml.EmailAddressNameIDFormat
	case "persistent":
		sp.ServiceProvider.AuthnNameIDFormat = saml.PersistentNameIDFormat
	case "transient":
		sp.ServiceProvider.AuthnNameIDFormat = saml.TransientNameIDFormat
	case "unspecified":
		sp.ServiceProvider.AuthnNameIDFormat = saml.UnspecifiedNameIDFormat
	case "":
	default:
		return errors.Errorf("invalid saml-nameid-format: %q", *samlNIDF)
	}

	// An empty value keeps whatever signature method samlsp.New chose.
	switch *samlSignMethod {
	case "sha1":
		sp.ServiceProvider.SignatureMethod = dsig.RSASHA1SignatureMethod
	case "sha256":
		sp.ServiceProvider.SignatureMethod = dsig.RSASHA256SignatureMethod
	case "sha512":
		sp.ServiceProvider.SignatureMethod = dsig.RSASHA512SignatureMethod
	case "":
	default:
		return errors.Errorf("invalid saml-signature-method: %q", *samlSignMethod)
	}

	samlSP = sp
	return nil
}
// samlFilter converts a completed SAML login into a beyond cookie session.
//
// It looks up the SAML session for r and reads the attribute named by
// -saml-session-key. It stores that value under "user" in the cookie session
// named by *cookieName, creating a new session if the existing one cannot be
// loaded. It then deletes the SAML session. It returns true when a user was
// recorded. It returns false when the request has no SAML session carrying
// attributes, or when the attribute is empty.
func samlFilter(w http.ResponseWriter, r *http.Request) bool {
	// The lookup error is deliberately ignored: anything that fails the
	// attribute assertion below is treated as "not logged in via SAML".
	current, _ := samlSP.Session.GetSession(r)
	withAttrs, ok := current.(samlsp.SessionWithAttributes)
	if !ok {
		// sessions without mappings will redirect infinitely
		return false
	}

	user := withAttrs.GetAttributes().Get(*samlAttr)
	if user == "" {
		// nil IdP assertion unlikely
		return false
	}

	sess, err := store.Get(r, *cookieName)
	if err != nil {
		sess = store.New(*cookieName)
	}
	sess.Values["user"] = user
	sess.Save(w)

	// The SAML session has served its purpose once the user is recorded.
	samlSP.Session.DeleteSession(w, r)
	return true
}