diff --git a/command/client.go b/command/client.go index c100d5e..b2f7149 100644 --- a/command/client.go +++ b/command/client.go @@ -24,6 +24,8 @@ const ( MSG_ASK_HOSTNAME = "Key server's host name" MSG_ASK_PORT = "Key server's port number" MSG_ASK_CA = "(Optional) PEM-encoded CA certificate of key server" + MSG_ASK_CLIENT_CERT = "If key server will validate client identity, enter path to PEM-encoded client certificate" + MSG_ASK_CLIENT_CERT_KEY = "If key server will validate client identity, enter path to PEM-encoded client key" MSG_ASK_DIFF_HOST = `Previously, this computer used "%s" as its key server; now you wish to use "%s". Only a single key server can be used to unlock all encrypted disks on this computer. Do you wish to proceed and switch to the new key server?` @@ -59,7 +61,7 @@ The encryption sequence will carry out the following tasks: ) // Prompt user to enter key server's CA file, host name, and port. Defaults are provided by existing configuration. -func PromptForKeyServer() (sysconf *sys.Sysconfig, caFile, host string, port int, err error) { +func PromptForKeyServer() (sysconf *sys.Sysconfig, caFile, certFile, certKeyFile, host string, port int, err error) { sysconf, err = sys.ParseSysconfigFile(CLIENT_CONFIG_PATH, true) if err != nil { return @@ -76,6 +78,16 @@ func PromptForKeyServer() (sysconf *sys.Sysconfig, caFile, host string, port int if caFile = sys.InputAbsFilePath(false, defaultCAFile, MSG_ASK_CA); caFile == "" { caFile = defaultCAFile } + defaultCertFile := sysconf.GetString(keyserv.CLIENT_CONF_CERT, "") + if certFile = sys.InputAbsFilePath(false, defaultCertFile, MSG_ASK_CLIENT_CERT); certFile == "" { + certFile = defaultCertFile + } + if certFile != "" { + defaultCertKeyFile := sysconf.GetString(keyserv.CLIENT_CONF_CERT_KEY, "") + if certKeyFile = sys.InputAbsFilePath(false, defaultCertKeyFile, MSG_ASK_CLIENT_CERT_KEY); certKeyFile == "" { + certKeyFile = defaultCertKeyFile + } + } return } @@ -84,7 +96,7 @@ func EncryptFS() error { sys.LockMem() // Prompt for connection details - sysconf, caFile, host, port, err := PromptForKeyServer() + sysconf, caFile, certFile, certKeyFile, host, port, err := PromptForKeyServer() if err != nil { return err } @@ -96,7 +108,7 @@ func EncryptFS() error { } // Check server connectivity before commencing encryption - client, password, err := ConnectToKeyServer(caFile, fmt.Sprintf("%s:%d", host, port)) + client, password, err := ConnectToKeyServer(caFile, certFile, certKeyFile, fmt.Sprintf("%s:%d", host, port)) if err != nil { return err } @@ -140,6 +152,8 @@ func EncryptFS() error { sysconf.Set(keyserv.CLIENT_CONF_HOST, host) sysconf.Set(keyserv.CLIENT_CONF_PORT, strconv.Itoa(port)) sysconf.Set(keyserv.CLIENT_CONF_CA, caFile) + sysconf.Set(keyserv.CLIENT_CONF_CERT, certFile) + sysconf.Set(keyserv.CLIENT_CONF_CERT_KEY, certKeyFile) if err := ioutil.WriteFile(CLIENT_CONFIG_PATH, []byte(sysconf.ToText()), 0600); err != nil { return fmt.Errorf(MSG_E_SAVE_SYSCONF, CLIENT_CONFIG_PATH, err) } @@ -151,11 +165,11 @@ func EncryptFS() error { // Sub-command: forcibly unlock all file systems that have their keys on a key server. func ManOnlineUnlockFS() error { sys.LockMem() - _, caFile, host, port, err := PromptForKeyServer() + _, caFile, certFile, certKeyFile, host, port, err := PromptForKeyServer() if err != nil { return err } - client, password, err := ConnectToKeyServer(caFile, fmt.Sprintf("%s:%d", host, port)) + client, password, err := ConnectToKeyServer(caFile, certFile, certKeyFile, fmt.Sprintf("%s:%d", host, port)) if err != nil { return err } @@ -228,7 +242,11 @@ func EraseKey() error { return errors.New(MSG_E_ERASE_NO_CONF) } caFile := sysconf.GetString(keyserv.CLIENT_CONF_CA, "") - client, password, err := ConnectToKeyServer(caFile, fmt.Sprintf("%s:%d", host, port)) + client, password, err := ConnectToKeyServer( + caFile, + sysconf.GetString(keyserv.CLIENT_CONF_CERT, ""), + sysconf.GetString(keyserv.CLIENT_CONF_CERT_KEY, ""), + fmt.Sprintf("%s:%d", host, port)) if err != nil { return err } diff --git a/command/server.go b/command/server.go index c2e876d..c3e5add 100644 --- a/command/server.go +++ b/command/server.go @@ -29,7 +29,7 @@ const ( ) // Interactively read server password from terminal, then use the password to ping RPC server. -func ConnectToKeyServer(caFile, keyServer string) (client *keyserv.CryptClient, password string, err error) { +func ConnectToKeyServer(caFile, certFile, keyFile, keyServer string) (client *keyserv.CryptClient, password string, err error) { sys.LockMem() serverAddr := keyServer port := keyserv.SRV_DEFAULT_PORT @@ -52,13 +52,17 @@ func ConnectToKeyServer(caFile, keyServer string) (client *keyserv.CryptClient, customCA = caFileContent } // Initialise client and test connectivity with the server - client, err = keyserv.NewCryptClient(serverAddr, port, customCA) + client, err = keyserv.NewCryptClient(serverAddr, port, customCA, certFile, keyFile) if err != nil { return nil, "", err } password = sys.InputPassword(true, "", "Enter key server's password (no echo)") fmt.Fprintf(os.Stderr, "Establishing connection to %s on port %d...\n", serverAddr, port) - if err := client.Ping(keyserv.PingRequest{Password: password}); err != nil { + salt, err := client.GetSalt() + if err != nil { + return nil, "", err + } + if err := client.Ping(keyserv.PingRequest{Password: keyserv.HashPassword(salt, password)}); err != nil { return nil, "", err } return @@ -229,6 +233,23 @@ Important notes for client computers: "Key database directory"); keyDBDir != "" { sysconf.Set(keyserv.SRV_CONF_KEYDB_DIR, keyDBDir) } + // Walk through client certificate verification settings + validateClient := sys.InputBool("Should clients present their certificate in order to access this server?") + sysconf.Set(keyserv.SRV_CONF_TLS_VALIDATE_CLIENT, validateClient) + if validateClient { + sysconf.Set(keyserv.SRV_CONF_TLS_CA, sys.Input(true, "", "PEM-encoded TLS certificate authority that will issue client certificates")) + } + // Walk through KMIP settings + useExternalKMIPServer := sys.InputBool("Should encryption keys be kept on a KMIP-compatible key management appliance?") + if useExternalKMIPServer { + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_HOST, sys.Input(true, "", "KMIP server host name")) + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_PORT, sys.InputInt(true, 5696, 1, 65535, "KMIP port number")) + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_USER, sys.Input(false, "", "KMIP username")) + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_PASS, sys.Input(false, "", "KMIP password")) + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_TLS_CA, sys.Input(false, "", "PEM-encoded TLS certificate authority that issued KMIP server certificate")) + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_TLS_CERT, sys.Input(false, "", "PEM-encoded TLS client identitiy certificate")) + sysconf.Set(keyserv.SRV_CONF_KMIP_SERVER_TLS_KEY, sys.Input(false, "", "PEM-encoded TLS client identitiy certificate key")) + } // Walk through optional email settings fmt.Println("\nTo enable Email notifications, enter the following parameters:") if mta := sys.Input(false, @@ -339,6 +360,7 @@ func KeyRPCDaemon() error { if err := srv.ListenRPC(); err != nil { return fmt.Errorf("Failed to listen for connections: %v", err) } + srv.HandleConnections() // intentionally block here return nil } @@ -351,12 +373,13 @@ func ListKeys() error { recList := db.List() fmt.Printf("Total: %d records (date and time are in zone %s)\n", len(recList), time.Now().Format("MST")) // Print mount point last, making output possible to be parsed by a program - // Max field length: 15 (IP), 19 (IP When), 36 (UUID), 9 (Max Active), 9 (Current Active) last field (mount point) - fmt.Println("Used By When UUID Max.Users Num.Users Mount Point") + // Max field length: 15 (IP), 19 (IP When), 12(ID), 36 (UUID), 9 (Max Active), 9 (Current Active) last field (mount point) + fmt.Println("Used By When ID UUID Max.Users Num.Users Mount Point") for _, rec := range recList { outputTime := time.Unix(rec.LastRetrieval.Timestamp, 0).Format(TIME_OUTPUT_FORMAT) rec.RemoveDeadHosts() - fmt.Printf("%-15s %-19s %-36s %-9s %-9s %s\n", rec.LastRetrieval.IP, outputTime, rec.UUID, + fmt.Printf("%-15s %-19s %-12s %-36s %-9s %-9s %s\n", rec.LastRetrieval.IP, outputTime, + rec.ID, rec.UUID, strconv.Itoa(rec.MaxActive), strconv.Itoa(len(rec.AliveMessages)), rec.MountPoint) } return nil @@ -391,7 +414,7 @@ func EditKey(uuid string) error { rec.AliveCount = roundedAliveTimeout / routine.REPORT_ALIVE_INTERVAL_SEC } // Write record file and restart server to let it reload all records into memory - if err := db.Upsert(rec); err != nil { + if _, err := db.Upsert(rec); err != nil { return fmt.Errorf("Failed to update record - %v", err) } fmt.Println("Record has been updated successfully.") diff --git a/fs/crypt.go b/fs/crypt.go index 79bec33..bf18ee8 100644 --- a/fs/crypt.go +++ b/fs/crypt.go @@ -21,13 +21,13 @@ const ( ) // Call cryptsetup luksFormat on the block device node. -func CryptFormat(key []byte, blockDev string) error { +func CryptFormat(key []byte, blockDev, uuid string) error { if err := CheckBlockDevice(blockDev); err != nil { return err } _, stdout, stderr, err := sys.Exec(bytes.NewReader(key), nil, nil, BIN_CRYPTSETUP, "--batch-mode", "--cipher", LUKS_CIPHER, "--hash", LUKS_HASH, "--key-size", LUKS_KEY_SIZE_S, - "luksFormat", "--key-file=-", blockDev) + "luksFormat", "--key-file=-", blockDev, "--uuid="+uuid) if err != nil { return fmt.Errorf("CryptFormat: failed to format \"%s\" - %v %s %s", blockDev, err, stdout, stderr) } @@ -133,7 +133,7 @@ func CryptStatus(name string) (mapping CryptMapping, err error) { _, stdout, _, _ := sys.Exec(nil, nil, nil, BIN_CRYPTSETUP, "status", name) mapping = ParseCryptStatus(stdout) if !mapping.IsValid() { - err = fmt.Errorf("CryptStatus: failed to retrieve a valid output for \"%s\"", name) + err = fmt.Errorf("CryptStatus: failed to retrieve a valid output for \"%s\", gathered information is: %+v", name, mapping) } return } diff --git a/fs/crypt_test.go b/fs/crypt_test.go index 2a33764..65aa44b 100644 --- a/fs/crypt_test.go +++ b/fs/crypt_test.go @@ -6,7 +6,7 @@ import "testing" // The unit test simply makes sure that the functions do not crash, it does not set up an encrypted device node. func TestCryptSetup(t *testing.T) { - if err := CryptFormat([]byte{}, "doesnotexist"); err == nil { + if err := CryptFormat([]byte{}, "doesnotexist", "testuuid"); err == nil { t.Fatal("did not error") } if err := CryptOpen([]byte{}, "doesnotexist", "doesnotexist"); err == nil { diff --git a/keydb/db.go b/keydb/db.go index 491f424..aaadafe 100644 --- a/keydb/db.go +++ b/keydb/db.go @@ -51,7 +51,7 @@ Caller should consider ot lock memory. */ func OpenDBOneRecord(dir, recordUUID string) (db *DB, err error) { if err := os.MkdirAll(dir, DB_DIR_FILE_MODE); err != nil { - return nil, fmt.Errorf("OpenDB: failed to make db directory \"%s\" - %v", dir, err) + return nil, fmt.Errorf("OpenDBOneRecord: failed to make db directory \"%s\" - %v", dir, err) } db = &DB{Dir: dir, Lock: new(sync.RWMutex), RecordsByUUID: map[string]Record{}, RecordsByID: map[string]Record{}} keyRecord, err := db.LoadRecord(path.Join(dir, recordUUID)) @@ -81,7 +81,7 @@ func (db *DB) ReloadDB() error { db.RecordsByID = make(map[string]Record) keyFiles, err := ioutil.ReadDir(db.Dir) if err != nil { - return fmt.Errorf("ReloadDB: failed to read directory \"%s\" - %v", db.Dir, err) + return fmt.Errorf("DB.ReloadDB: failed to read directory \"%s\" - %v", db.Dir, err) } var lastSequenceNum int64 @@ -273,11 +273,11 @@ func (db *DB) Erase(uuid string) error { db.Lock.Lock() defer db.Lock.Unlock() if _, exists := db.RecordsByUUID[uuid]; !exists { - return fmt.Errorf("DB.Delete: record '%s' does not exist", uuid) + return fmt.Errorf("DB.Erase: record '%s' does not exist", uuid) } delete(db.RecordsByUUID, uuid) if err := fs.SecureErase(path.Join(db.Dir, uuid), true); err != nil { - return fmt.Errorf("DB.Delete: failed to delete db record for %s - %v", uuid, err) + return fmt.Errorf("DB.Erase: failed to delete db record for %s - %v", uuid, err) } return nil } diff --git a/keydb/db_test.go b/keydb/db_test.go index c49fb10..9b29f64 100644 --- a/keydb/db_test.go +++ b/keydb/db_test.go @@ -50,17 +50,17 @@ func TestRecordCRUD(t *testing.T) { rec2Alive := rec2 rec2Alive.LastRetrieval = aliveMsg rec2Alive.AliveMessages = map[string][]AliveMessage{aliveMsg.IP: []AliveMessage{aliveMsg}} - if seq, err := db.Upsert(rec1); err != nil || seq != 1 { + if seq, err := db.Upsert(rec1); err != nil || seq != "1" { t.Fatal(err, seq) } - if seq, err := db.Upsert(rec2); err != nil || seq != 2 { + if seq, err := db.Upsert(rec2); err != nil || seq != "2" { t.Fatal(err, seq) } // Match sequence number in my copy of records with their should-be ones - rec1.SequenceNum = 1 - rec1Alive.SequenceNum = 1 - rec2.SequenceNum = 2 - rec2Alive.SequenceNum = 2 + rec1.ID = "1" + rec1Alive.ID = "1" + rec2.ID = "2" + rec2Alive.ID = "2" // Select one record and then select both records if found, rejected, missing := db.Select(aliveMsg, true, "1", "doesnotexist"); !reflect.DeepEqual(found, map[string]Record{rec1.UUID: rec1Alive}) || !reflect.DeepEqual(rejected, []string{}) || @@ -89,7 +89,7 @@ func TestRecordCRUD(t *testing.T) { if len(db.RecordsByUUID["1"].AliveMessages["ip1"]) != 2 || len(db.RecordsByUUID["2"].AliveMessages["ip1"]) != 2 { t.Fatal(db.RecordsByUUID) } - if len(db.RecordsByID[1].AliveMessages["ip1"]) != 2 || len(db.RecordsByID[2].AliveMessages["ip1"]) != 2 { + if len(db.RecordsByID["1"].AliveMessages["ip1"]) != 2 || len(db.RecordsByID["2"].AliveMessages["ip1"]) != 2 { t.Fatal(db.RecordsByUUID) } // Erase a record @@ -135,7 +135,7 @@ func TestOpenDBOneRecord(t *testing.T) { }, AliveMessages: make(map[string][]AliveMessage), } - if seq, err := db.Upsert(rec); err != nil || seq != 1 { + if seq, err := db.Upsert(rec); err != nil || seq != "1" { t.Fatal(err) } dbOneRecord, err := OpenDBOneRecord(TEST_DIR, "a") @@ -145,17 +145,17 @@ func TestOpenDBOneRecord(t *testing.T) { if len(dbOneRecord.RecordsByUUID) != 1 { t.Fatal(dbOneRecord.RecordsByUUID) } - rec.SequenceNum = 1 + rec.ID = "1" if recA, found := dbOneRecord.GetByUUID("a"); !found || !reflect.DeepEqual(recA, rec) { t.Fatal(recA, found) } - if recA, found := dbOneRecord.GetByID(1); !found || !reflect.DeepEqual(recA, rec) { + if recA, found := dbOneRecord.GetByID("1"); !found || !reflect.DeepEqual(recA, rec) { t.Fatal(recA, found) } if _, found := dbOneRecord.GetByUUID("doesnotexist"); found { t.Fatal("false positive") } - if _, found := dbOneRecord.GetByID(78598123); found { + if _, found := dbOneRecord.GetByID("78598123"); found { t.Fatal("false positive") } } @@ -209,18 +209,18 @@ func TestList(t *testing.T) { } rec3NoKey := rec3 rec3NoKey.Key = nil - if seq, err := db.Upsert(rec1); err != nil || seq != 1 { + if seq, err := db.Upsert(rec1); err != nil || seq != "1" { t.Fatal(err, seq) } - if seq, err := db.Upsert(rec2); err != nil || seq != 2 { + if seq, err := db.Upsert(rec2); err != nil || seq != "2" { t.Fatal(err) } - if seq, err := db.Upsert(rec3); err != nil || seq != 3 { + if seq, err := db.Upsert(rec3); err != nil || seq != "3" { t.Fatal(err) } - rec1NoKey.SequenceNum = 1 - rec2NoKey.SequenceNum = 2 - rec3NoKey.SequenceNum = 3 + rec1NoKey.ID = "1" + rec2NoKey.ID = "2" + rec3NoKey.ID = "3" recs := db.List() if !reflect.DeepEqual(recs[0], rec1NoKey) || !reflect.DeepEqual(recs[1], rec3NoKey) || diff --git a/keydb/record.go b/keydb/record.go index f742690..59e22d5 100644 --- a/keydb/record.go +++ b/keydb/record.go @@ -27,17 +27,16 @@ When stored on disk, the record resides in a file encoded in gob. The binary encoding method is intentionally chosen to deter users from manually editing the files on disk. */ type Record struct { - UUID string // partition UUID ID string // KMIP key ID - KeyOnKMIP bool // true only if key content is located on an external KMIP server - CreationTime time.Time // creation time Version int // Record version + CreationTime time.Time // creation time Key []byte // encryption key in plain form - MountPoint string // mount point on client computer - MountOptions []string // file system's mount options + UUID string // file system uuid + MountPoint string // mount point of the file system + MountOptions []string // mount options of the file system MaxActive int // maximum allowed active key users (computers), set to <=0 to allow unlimited. LastRetrieval AliveMessage // the most recent host who retrieved this key - AliveIntervalSec int // interval in seconds at which all client computers holding this key must report their liveness + AliveIntervalSec int // interval in seconds at which all user of the file system holding this key must report they're online AliveCount int // a client computer is considered dead after missing so many alive messages AliveMessages map[string][]AliveMessage // recent alive messages (latest is last), string map key is the host IP as seen by this server. } @@ -170,11 +169,12 @@ func (rec *Record) Deserialise(in []byte) error { // Format all attributes (except the binary key) for pretty printing, using the specified separator. func (rec *Record) FormatAttrs(separator string) string { - return fmt.Sprintf(`Timestamp="%d"%sIP="%s"%sHostname="%s"%sFileSystemUUID="%s"%sMountPoint="%s"%sMountOptions="%s"`, + return fmt.Sprintf(`Timestamp="%d"%sIP="%s"%sHostname="%s"%sFileSystemUUID="%s"%sKMIPID="%s"%sMountPoint="%s"%sMountOptions="%s"`, rec.LastRetrieval.Timestamp, separator, rec.LastRetrieval.IP, separator, rec.LastRetrieval.Hostname, separator, rec.UUID, separator, + rec.ID, separator, strings.Replace(rec.MountPoint, `"`, `\"`, -1), separator, rec.GetMountOptionStr()) } diff --git a/keydb/record_test.go b/keydb/record_test.go index 4f5aac1..f281ebc 100644 --- a/keydb/record_test.go +++ b/keydb/record_test.go @@ -325,6 +325,7 @@ func TestRecordAliveMessage2(t *testing.T) { func TestRecord(t *testing.T) { rec := Record{ UUID: "testuuid", + ID: "testid", Key: []byte{0, 1, 2, 3}, MountPoint: "/tmp/a", MountOptions: []string{"rw", "noatime"}, @@ -366,7 +367,7 @@ func TestRecord(t *testing.T) { } // Format as string - if s := rec.FormatAttrs("|"); s != `Timestamp="123456"|IP="ip1"|Hostname="host1"|FileSystemUUID="testuuid"|MountPoint="/tmp/a"|MountOptions="rw,noatime"` { + if s := rec.FormatAttrs("|"); s != `Timestamp="123456"|IP="ip1"|Hostname="host1"|FileSystemUUID="testuuid"|KMIPID="testid"|MountPoint="/tmp/a"|MountOptions="rw,noatime"` { t.Fatal(s) } } diff --git a/keyserv/kmip_client.go b/keyserv/kmip_client.go index 01f0969..c586d92 100644 --- a/keyserv/kmip_client.go +++ b/keyserv/kmip_client.go @@ -7,7 +7,6 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/HouzuoGuo/cryptctl/fs" "github.com/HouzuoGuo/cryptctl/kmip/structure" "github.com/HouzuoGuo/cryptctl/kmip/ttlv" "io" @@ -21,7 +20,8 @@ const ( Both server and client refuse to accept a structure larger than this number. The number is reasonable and big enough for all three operations supported by server and client: create, get, and destroy. */ - MaxKMIPStructLen = 65536 + MaxKMIPStructLen = 65536 + KMIPAESKeySizeBits = 256 // The only kind of AES encryption key the KMIP server and client will expect to use ) /* @@ -48,7 +48,7 @@ func NewKMIPClient(host string, port int, username, password string, caCertPEM [ Password: password, TLSConfig: new(tls.Config), } - if caCertPEM != nil || len(caCertPEM) > 0 { + if caCertPEM != nil && len(caCertPEM) > 0 { // Use custom CA caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCertPEM) { @@ -82,8 +82,12 @@ func ReadFullTTLV(reader io.Reader) (ttlv.Item, error) { } // Read remainder of request structure ttlValue = make([]byte, structLen) - if _, err := reader.Read(ttlValue); err != nil { - return nil, err + if n, err := reader.Read(ttlValue); err != nil { + if int32(n) == structLen { + err = nil + } else { + return nil, err + } } // Assemble TTL header and structure value for deserialisation fullTTLV = make([]byte, 8+len(ttlValue)) @@ -105,7 +109,9 @@ func (client *KMIPClient) MakeRequest(request structure.SerialisedItem) (structu return nil, fmt.Errorf("KMIPClient.MakeRequest: connection to \"%s\" failed - %v", addr, err) } defer conn.Close() - if _, err = conn.Write(ttlv.EncodeAny(request.SerialiseToTTLV())); err != nil { + serialisedRequest := request.SerialiseToTTLV() + encodedRequest := ttlv.EncodeAny(serialisedRequest) + if _, err = conn.Write(encodedRequest); err != nil { return nil, fmt.Errorf("KMIPClient.MakeRequest: failed to write request - %v", err) } ttlvResp, err := ReadFullTTLV(conn) @@ -163,11 +169,13 @@ func ResponseItemToError(resp structure.SResponseBatchItem) error { func (client *KMIPClient) CreateKey(keyName string) (id string, err error) { defer func() { // In the unlikely case that a misbehaving server causes client to crash. - if r := recover(); r != nil { - msg := fmt.Sprintf("KMIPClient.CreateKey: the function crashed due to programming error - %v", r) - log.Print(msg) - err = errors.New(msg) - } + /* + if r := recover(); r != nil { + msg := fmt.Sprintf("KMIPClient.CreateKey: the function crashed due to programming error - %v", r) + log.Print(msg) + err = errors.New(msg) + } + */ }() resp, err := client.MakeRequest(&structure.SCreateRequest{ SRequestHeader: client.GetRequestHeader(), @@ -179,15 +187,24 @@ func (client *KMIPClient) CreateKey(keyName string) (id string, err error) { Attributes: []structure.SAttribute{ { TAttributeName: ttlv.Text{Value: structure.ValAttributeNameCryptoAlg}, - AttributeValue: &ttlv.Enumeration{Value: structure.ValCryptoAlgoAES}, + AttributeValue: &ttlv.Enumeration{ + TTL: ttlv.TTL{Tag: structure.TagAttributeValue}, + Value: structure.ValCryptoAlgoAES, + }, }, { TAttributeName: ttlv.Text{Value: structure.ValAttributeNameCryptoLen}, - AttributeValue: &ttlv.Enumeration{Value: fs.LUKS_KEY_SIZE_I}, + AttributeValue: &ttlv.Integer{ + TTL: ttlv.TTL{Tag: structure.TagAttributeValue}, + Value: KMIPAESKeySizeBits, // keep in mind that key size is in bits + }, }, { TAttributeName: ttlv.Text{Value: structure.ValAttributeNameCryptoUsageMask}, - AttributeValue: &ttlv.Enumeration{Value: structure.MaskCryptoUsageEncrypt | structure.MaskCryptoUsageDecrypt}, + AttributeValue: &ttlv.Integer{ + TTL: ttlv.TTL{Tag: structure.TagAttributeValue}, + Value: structure.MaskCryptoUsageEncrypt | structure.MaskCryptoUsageDecrypt, + }, }, { TAttributeName: ttlv.Text{Value: structure.ValAttributeNameKeyName}, @@ -220,7 +237,7 @@ func (client *KMIPClient) GetKey(id string) (key []byte, err error) { defer func() { // In the unlikely case that a misbehaving server causes client to crash. if r := recover(); r != nil { - msg := fmt.Sprintf("KMIPClient.GetKey: (ID %d) the function crashed due to programming error - %v", id, r) + msg := fmt.Sprintf("KMIPClient.GetKey: (ID %s) the function crashed due to programming error - %v", id, r) log.Print(msg) err = errors.New(msg) } @@ -242,8 +259,8 @@ func (client *KMIPClient) GetKey(id string) (key []byte, err error) { return } key = typedResp.SResponseBatchItem.SResponsePayload.(*structure.SResponsePayloadGet).SSymmetricKey.SKeyBlock.SKeyValue.BKeyMaterial.Value - if key == nil || len(key) != fs.LUKS_KEY_SIZE_I { - err = errors.New("KMIPClient.GetKey: (ID %d) key content looks wrong") + if key == nil || len(key) != KMIPAESKeySizeBits/8 { + err = fmt.Errorf("KMIPClient.GetKey: (ID %s) key content looks wrong (%d)", id, len(key)) } return } @@ -253,7 +270,7 @@ func (client *KMIPClient) DestroyKey(id string) (err error) { defer func() { // In the unlikely case that a misbehaving server causes client to crash. if r := recover(); r != nil { - msg := fmt.Sprintf("KMIPClient.GetKey: (ID %d) the function crashed due to programming error - %v", id, r) + msg := fmt.Sprintf("KMIPClient.GetKey: (ID %s) the function crashed due to programming error - %v", id, r) log.Print(msg) err = errors.New(msg) } @@ -270,6 +287,6 @@ func (client *KMIPClient) DestroyKey(id string) (err error) { if err != nil { return } - err = ResponseItemToError(resp.(*structure.SGetResponse).SResponseBatchItem) + err = ResponseItemToError(resp.(*structure.SDestroyResponse).SResponseBatchItem) return } diff --git a/keyserv/kmip_server.go b/keyserv/kmip_server.go index 2884579..3b7ad54 100644 --- a/keyserv/kmip_server.go +++ b/keyserv/kmip_server.go @@ -7,7 +7,6 @@ import ( "encoding/hex" "errors" "fmt" - "github.com/HouzuoGuo/cryptctl/fs" "github.com/HouzuoGuo/cryptctl/keydb" "github.com/HouzuoGuo/cryptctl/kmip/structure" "github.com/HouzuoGuo/cryptctl/kmip/ttlv" @@ -17,19 +16,19 @@ import ( "time" ) +const ( + LenKMIPRandomPass = 256 // length of random password validated by KMIP server +) + // Create a new disk encryption key out of entropy from cryptographic random pool. -func GetNewDiskEncryptionKey() []byte { - random := make([]byte, fs.LUKS_KEY_SIZE_I) +func GetNewDiskEncryptionKeyBits() []byte { + random := make([]byte, KMIPAESKeySizeBits/8) if _, err := rand.Read(random); err != nil { - log.Fatalf("GetNewDiskEncryptionKey: system is out of entropy - %v", err) + log.Fatalf("GetNewDiskEncryptionKeyBits: system is out of entropy - %v", err) } return random } -const ( - LenKMIPRandomPass = 256 // length of random password validated by KMIP server -) - /* A partially implemented KMIP protocol server that creates and serves encryption keys upon request. The implementation is specifically tailored to the requirements of RPC server, hence it does not validate client certificate and only relies @@ -55,27 +54,33 @@ func NewKMIPServer(db *keydb.DB, certFilePath, certKeyPath string) (*KMIPServer, return server, nil } -// Start KMIP server and process incoming requests, block caller until listener is told to shut down. -func (srv *KMIPServer) Listen() (err error) { +// Start KMIP server's listener. +func (srv *KMIPServer) Listen() error { // KMIP challenge is similar to shutdown challenge, but it is encoded into string. randPass := make([]byte, LenKMIPRandomPass) - if _, err = rand.Read(randPass); err != nil { - return + if _, err := rand.Read(randPass); err != nil { + return err } // In protocol specification, KMIP authentication password has to be a text string, hence the password challenge is encoded here. srv.PasswordChallenge = []byte(hex.EncodeToString(randPass)) - srv.Listener, err = tls.Listen("tcp", "127.0.0.1:0", srv.TLSConfig) + var err error + srv.Listener, err = tls.Listen("tcp", "localhost:0", srv.TLSConfig) if err != nil { - return + return err } log.Printf("KMIPServer.Listen: listening on 127.0.0.1:%d", srv.GetPort()) + return nil +} + +// Process incoming KMIP requests, block caller until listener is told to shut down. +func (srv *KMIPServer) HandleConnections() { for { - incoming, err := srv.Listener.Accept() + conn, err := srv.Listener.Accept() if err != nil { log.Printf("KMIPServer.Listen: quit now - %v", err) - return nil + return } - go srv.HandleConnection(incoming) + go srv.HandleConnection(conn) } } @@ -84,28 +89,39 @@ func (srv *KMIPServer) GetPort() int { return srv.Listener.Addr().(*net.TCPAddr).Port } +// Close listener and shutdown service. +func (srv *KMIPServer) Shutdown() { + if listener := srv.Listener; listener != nil { + srv.Listener.Close() + } +} + /* Converse with KMIP client. This KMIP server is made only for cryptctl's own KMIP client, hence a lot of validation work are skipped intentionally. Normally KMIP service is capable of handling more than one requests per connection, but cryptctl's own KMIP client only submits one request per connection. */ -func (srv *KMIPServer) HandleConnection(client net.Conn) { +func (srv *KMIPServer) HandleConnection(conn net.Conn) { defer func() { /* In the unlikely case that user connects a fully featured KMIP client to this server and the client unexpectedly triggers a buffer handling issue, the error is logged here and then ignored. */ if r := recover(); r != nil { - log.Printf("KMIPServer.HandleConnection: panic occured with client %s - %v", client.RemoteAddr().String(), r) + log.Printf("KMIPServer.HandleConnection: panic occured with client %s - %v", conn.RemoteAddr().String(), r) } }() - defer client.Close() + defer conn.Close() var err error var successfulDecodeAttempt structure.SerialisedItem decodeAttempts := []structure.SerialisedItem{&structure.SCreateRequest{}, &structure.SGetRequest{}, &structure.SDestroyRequest{}} - log.Printf("KMIPServer.HandleConnection: connected from %s", client.RemoteAddr().String()) - ttlvItem, err := ReadFullTTLV(client) + log.Printf("KMIPServer.HandleConnection: connected from %s", conn.RemoteAddr().String()) + ttlvItem, err := ReadFullTTLV(conn) + if err != nil { + log.Printf("KMIPServer.HandleConnection: IO failure occured with client %s - %v", conn.RemoteAddr().String(), err) + return + } // Try decoding request into request structures and see which one succeeds for _, attempt := range decodeAttempts { if decodeErr := attempt.DeserialiseFromTTLV(ttlvItem); decodeErr == nil { @@ -114,14 +130,15 @@ func (srv *KMIPServer) HandleConnection(client net.Conn) { } } if successfulDecodeAttempt == nil { - err = fmt.Errorf("Server does not understand request:\n%s", ttlv.DebugTTLVItem(0, ttlvItem)) + err = fmt.Errorf("Server does not understand request from client %s", conn.RemoteAddr()) goto Error } - if err = srv.HandleRequest(successfulDecodeAttempt, client); err != nil { + if err = srv.HandleRequest(successfulDecodeAttempt, conn); err != nil { goto Error } + return Error: - log.Printf("KMIPServer.HandleConnection: error occured with client %s - %v", client.RemoteAddr().String(), err) + log.Printf("KMIPServer.HandleConnection: error occured with client %s - %v", conn.RemoteAddr().String(), err) } // Try to match KMIP request's password with server's challenge. If there is a mismatch, return an error. @@ -137,7 +154,7 @@ func (srv *KMIPServer) CheckPassword(header structure.SRequestHeader) error { /* Handle a KMIP request, produce a response structure and send it back to client. */ -func (srv *KMIPServer) HandleRequest(req structure.SerialisedItem, client net.Conn) (err error) { +func (srv *KMIPServer) HandleRequest(req structure.SerialisedItem, conn net.Conn) (err error) { defer func() { /* The KMIP request only comes from cryptctl program itself, so all type assertions and attributes @@ -150,35 +167,23 @@ func (srv *KMIPServer) HandleRequest(req structure.SerialisedItem, client net.Co var resp structure.SerialisedItem switch t := req.(type) { case *structure.SCreateRequest: - if err = srv.CheckPassword(t.SRequestHeader); err != nil { - return - } - resp, err = srv.HandleCreateRequest(t) - if err != nil { - return + if err = srv.CheckPassword(t.SRequestHeader); err == nil { + resp, err = srv.HandleCreateRequest(t) } case *structure.SGetRequest: - if err := srv.CheckPassword(t.SRequestHeader); err != nil { - return err - } - resp, err = srv.HandleGetRequest(t) - if err != nil { - return + if err := srv.CheckPassword(t.SRequestHeader); err == nil { + resp, err = srv.HandleGetRequest(t) } case *structure.SDestroyRequest: - if err := srv.CheckPassword(t.SRequestHeader); err != nil { - return err - } - resp, err = srv.HandleDestroyRequest(t) - if err != nil { - return + if err := srv.CheckPassword(t.SRequestHeader); err == nil { + resp, err = srv.HandleDestroyRequest(t) } default: - return fmt.Errorf("KMIPServer.HandleRequest: unknown request type %s", reflect.TypeOf(req).String()) + err = fmt.Errorf("KMIPServer.HandleRequest: unknown request type %s", reflect.TypeOf(req).String()) } - client.SetWriteDeadline(time.Now().Add(KMIPTimeoutSec * time.Second)) - if _, err = client.Write(ttlv.EncodeAny(resp.SerialiseToTTLV())); err != nil { - return err + if err == nil { + conn.SetWriteDeadline(time.Now().Add(KMIPTimeoutSec * time.Second)) + _, err = conn.Write(ttlv.EncodeAny(resp.SerialiseToTTLV())) } return } @@ -188,7 +193,7 @@ func (srv *KMIPServer) HandleCreateRequest(req *structure.SCreateRequest) (*stru var keyName string for _, attr := range req.SRequestBatchItem.SRequestPayload.(*structure.SRequestPayloadCreate).STemplateAttribute.Attributes { if attr.TAttributeName.Value == structure.ValAttributeNameKeyName { - keyName = attr.AttributeValue.(*ttlv.Structure).Items[0].(ttlv.Text).Value + keyName = attr.AttributeValue.(*ttlv.Structure).Items[0].(*ttlv.Text).Value } } /* @@ -199,8 +204,9 @@ func (srv *KMIPServer) HandleCreateRequest(req *structure.SCreateRequest) (*stru kmipID, err := srv.DB.Upsert(keydb.Record{ UUID: keyName, CreationTime: creationTime, - Key: GetNewDiskEncryptionKey(), + Key: GetNewDiskEncryptionKeyBits(), }) + log.Printf("KMIPServer.HandleCreateRequest: just created a key named \"%s\" ID \"%s\"", keyName, kmipID) if err != nil { return nil, err } diff --git a/keyserv/kmip_server_test.go b/keyserv/kmip_server_test.go new file mode 100644 index 0000000..3a4f94e --- /dev/null +++ b/keyserv/kmip_server_test.go @@ -0,0 +1,14 @@ +package keyserv + +import ( + "reflect" + "testing" +) + +func TestGetNewDiskEncryptionKeyBits(t *testing.T) { + key1 := GetNewDiskEncryptionKeyBits() + key2 := GetNewDiskEncryptionKeyBits() + if len(key1) != KMIPAESKeySizeBits/8 || reflect.DeepEqual(key1, key2) { + t.Fatal(key1, key2) + } +} diff --git a/keyserv/kmip_test.go b/keyserv/kmip_test.go index e46e690..80e70bf 100644 --- a/keyserv/kmip_test.go +++ b/keyserv/kmip_test.go @@ -1,10 +1,12 @@ package keyserv import ( + "encoding/hex" "github.com/HouzuoGuo/cryptctl/keydb" "io/ioutil" "os" "path" + "reflect" "testing" "time" ) @@ -21,26 +23,128 @@ func TestKMIP(t *testing.T) { } var server *KMIPServer + var serverHasShutdown bool + server, err = NewKMIPServer(db, path.Join(PkgInGopath, "keyserv", "rpc_test.crt"), path.Join(PkgInGopath, "keyserv", "rpc_test.key")) + if err != nil { + t.Fatal(err) + } + if server.Listen(); err != nil { + t.Fatal(err) + } go func() { - var err error - server, err = NewKMIPServer(db, path.Join(PkgInGopath, "keyserv", "rpc_test.crt"), path.Join(PkgInGopath, "keyserv", "rpc_test.key")) - if err != nil { - t.Fatal(err) - } - if err := server.Listen(); err != nil { - t.Fatal(err) - } + server.HandleConnections() + serverHasShutdown = true }() + caCert, err := ioutil.ReadFile(path.Join(PkgInGopath, "keyserv", "rpc_test.crt")) + if err != nil { + t.Fatal(err) + } // Expect server to start in a second time.Sleep(1 * time.Second) - client, err := NewKMIPClient("127.0.0.1", server.GetPort(), "", string(server.PasswordChallenge), nil, + client, err := NewKMIPClient("localhost", server.GetPort(), "username-does-not-matter", string(server.PasswordChallenge), caCert, path.Join(PkgInGopath, "keyserv", "rpc_test.crt"), path.Join(PkgInGopath, "keyserv", "rpc_test.key")) - // In case the test certificate name does not match 127.0.0.1 - client.TLSConfig.InsecureSkipVerify = true if err != nil { t.Fatal(err) } - if id, err := client.CreateKey("testname"); err != nil || id != "1" { + // Create two keys + if id, err := client.CreateKey("test key 1"); err != nil || id != "1" { t.Fatal(err, id) } + if id, err := client.CreateKey("test key 2"); err != nil || id != "2" { + t.Fatal(err, id) + } + // Retrieve both keys and non-existent key + received1, err := client.GetKey("1") + if err != nil { + t.Fatal(err) + } + received2, err := client.GetKey("2") + if err != nil { + t.Fatal(err) + } + if _, err := client.GetKey("does not exist"); err == nil { + t.Fatal("did not error") + } + if reflect.DeepEqual(received1, received2) || len(received1) == 0 { + t.Fatal(hex.Dump(received1), hex.Dump(received2)) + } + // Destroy and retrieve again + if err := client.DestroyKey("1"); err != nil { + t.Fatal(err) + } + if err := client.DestroyKey("does not exist"); err == nil { + t.Fatal("did not error") + } + // Expect server to shut down within a second + server.Shutdown() + time.Sleep(1 * time.Second) + if !serverHasShutdown { + t.Fatal("did not shutdown") + } + // Calling shutdown multiple times should not cause panic + server.Shutdown() + server.Shutdown() +} + +func TestKMIPAgainstPyKMIP(t *testing.T) { + /* + A PyKMIP server can be started using the python code below: + import time + from kmip.services.server import * + + server = KmipServer( + hostname='0.0.0.0', + port=5696, + certificate_path='/etc/pykmip/server.crt', + key_path='/etc/pykmip/server.key', + ca_path='/etc/pykmip/ca.crt', + auth_suite='Basic', + config_path=None, + log_path='/etc/pykmip/server.log', + policy_path='/etc/pykmip/policy.json' + ) + + print("server about to start") + server.start() + print("server started") + server.serve() + print("connection served") + time.sleep(100) + */ + t.Skip("Start PyKMIP server manually and remove this skip statement to run this test case") + client, err := NewKMIPClient("127.0.0.1", 5696, "testuser", "testpass", nil, "/etc/pykmip/client.crt", "/etc/pykmip/client.key") + client.TLSConfig.InsecureSkipVerify = true + if err != nil { + t.Fatal(err) + } + // Create two keys + var id1, id2 string + if id1, err = client.CreateKey("test key 1"); err != nil || id1 == "" { + t.Fatal(err, id1) + } + if id2, err = client.CreateKey("test key 2"); err != nil || id2 == "" { + t.Fatal(err, id2) + } + // Retrieve both keys and non-existent key + received1, err := client.GetKey(id1) + if err != nil { + t.Fatal(err) + } + received2, err := client.GetKey(id2) + if err != nil { + t.Fatal(err) + } + if _, err := client.GetKey("does not exist"); err == nil { + t.Fatal("did not error") + } + if reflect.DeepEqual(received1, received2) || len(received1) == 0 { + t.Fatal(hex.Dump(received1), hex.Dump(received2)) + } + // Destroy and retrieve again + if err := client.DestroyKey(id1); err != nil { + t.Fatal(err) + } + if err := client.DestroyKey("does not exist"); err == nil { + t.Fatal("did not error") + } } diff --git a/keyserv/rpc_client.go b/keyserv/rpc_client.go index 311d873..2bbb2d3 100644 --- a/keyserv/rpc_client.go +++ b/keyserv/rpc_client.go @@ -23,12 +23,16 @@ const ( CLIENT_CONF_HOST = "KEY_SERVER_HOST" CLIENT_CONF_PORT = "KEY_SERVER_PORT" CLIENT_CONF_CA = "TLS_CA_PEM" + CLIENT_CONF_CERT = "TLS_CERT_PEM" + CLIENT_CONF_CERT_KEY = "TLS_CERT_KEY_PEM" TEST_RPC_PASS = "pass" ) type CryptClient struct { ServerHost string ServerPort int + TLSCert string + TLSKey string TLSConfig *tls.Config } @@ -36,13 +40,13 @@ type CryptClient struct { Initialise an RPC client. The function does not immediately establish a connection to server, connection is only made along with each RPC call. */ -func NewCryptClient(host string, port int, caCertPEM []byte) (*CryptClient, error) { +func NewCryptClient(host string, port int, caCertPEM []byte, certPath, certKeyPath string) (*CryptClient, error) { client := &CryptClient{ ServerHost: host, ServerPort: port, TLSConfig: new(tls.Config), } - if caCertPEM != nil || len(caCertPEM) > 0 { + if caCertPEM != nil && len(caCertPEM) > 0 { // Use custom CA caCertPool := x509.NewCertPool() if !caCertPool.AppendCertsFromPEM(caCertPEM) { @@ -50,6 +54,15 @@ func NewCryptClient(host string, port int, caCertPEM []byte) (*CryptClient, erro } client.TLSConfig.RootCAs = caCertPool } + if certPath != "" { + // Tell client to present its identity to server + clientCert, err := tls.LoadX509KeyPair(certPath, certKeyPath) + if err != nil { + return nil, err + } + client.TLSConfig.Certificates = []tls.Certificate{clientCert} + } + client.TLSConfig.BuildNameToCertificate() return client, nil } @@ -71,7 +84,7 @@ func NewCryptClientFromSysconfig(sysconf *sys.Sysconfig) (*CryptClient, error) { return nil, fmt.Errorf("NewCryptClientFromSysconfig: failed to read CA PEM file at \"%s\" - %v", ca, err) } } - return NewCryptClient(host, port, caCertPEM) + return NewCryptClient(host, port, caCertPEM, sysconf.GetString(CLIENT_CONF_CERT, ""), sysconf.GetString(CLIENT_CONF_CERT_KEY, "")) } /* @@ -97,6 +110,15 @@ func (client *CryptClient) DoRPC(fun func(*rpc.Client) error) error { return nil } +// Retrieve the salt that was used to hash server's access password. +func (client *CryptClient) GetSalt() (salt PasswordSalt, err error) { + err = client.DoRPC(func(rpcClient *rpc.Client) error { + var dummy DummyAttr + return rpcClient.Call(fmt.Sprintf(RPCObjNameFmt, "GetSalt"), &dummy, &salt) + }) + return +} + // Ping RPC server. Return an error if there is a communication mishap or server has not undergone the initial setup. func (client *CryptClient) Ping(req PingRequest) error { return client.DoRPC(func(rpcClient *rpc.Client) error { @@ -105,12 +127,12 @@ func (client *CryptClient) Ping(req PingRequest) error { }) } -// Save a new key record. -func (client *CryptClient) SaveKey(req SaveKeyReq) error { - return client.DoRPC(func(rpcClient *rpc.Client) error { - var dummy DummyAttr - return rpcClient.Call(fmt.Sprintf(RPCObjNameFmt, "SaveKey"), req, &dummy) +// Create a new key record. +func (client *CryptClient) CreateKey(req CreateKeyReq) (resp CreateKeyResp, err error) { + err = client.DoRPC(func(rpcClient *rpc.Client) error { + return rpcClient.Call(fmt.Sprintf(RPCObjNameFmt, "CreateKey"), req, &resp) }) + return } // Retrieve encryption keys without a password. @@ -165,10 +187,11 @@ func (client *CryptClient) Shutdown(req ShutdownReq) error { } // Start an RPC server in a testing configuration, return a client connected to the server and a teardown function. -func StartTestServer(tb testing.TB) (client *CryptClient, tearDown func(testing.TB)) { +func StartTestServer(tb testing.TB) (*CryptClient, func(testing.TB)) { keydbDir, err := ioutil.TempDir("", "cryptctl-rpctest") if err != nil { tb.Fatal(err) + return nil, nil } // Fill in configuration blanks (listen port is left at default) salt := NewSalt() @@ -185,34 +208,49 @@ func StartTestServer(tb testing.TB) (client *CryptClient, tearDown func(testing. srv, err := NewCryptServer(srvConf, Mailer{}) if err != nil { tb.Fatal(err) + return nil, nil } - tearDown = func(t testing.TB) { - if err := client.Shutdown(ShutdownReq{Challenge: srv.ShutdownChallenge}); err != nil { - t.Fatal(err) - } - if err := client.Ping(PingRequest{Password: TEST_RPC_PASS}); err == nil { - t.Fatal("server did not shutdown") - } - if err := os.RemoveAll(keydbDir); err != nil { - t.Fatal(err) - } + if err := srv.ListenRPC(); err != nil { + tb.Fatal(err) + return nil, nil } + go srv.HandleConnections() // The test certificate's CN is "localhost" caPath := path.Join(PkgInGopath, "keyrpc", "rpc_test.crt") certContent, err := ioutil.ReadFile(caPath) // Construct a client via function parameters - client, err = NewCryptClient("localhost", 3737, certContent) + client, err := NewCryptClient("localhost", 3737, certContent, "", "") if err != nil { tb.Fatal(err) + return nil, nil } - go srv.ListenRPC() + client.TLSConfig.InsecureSkipVerify = true // Server should start within about 2 seconds - for i := 0; i < 10; i++ { - if err := client.Ping(PingRequest{Password: TEST_RPC_PASS}); err == nil { - return + serverReady := false + for i := 0; i < 20; i++ { + if err := client.Ping(PingRequest{Password: HashPassword(salt, TEST_RPC_PASS)}); err == nil { + serverReady = true + break } time.Sleep(100 * time.Millisecond) } - tb.Fatal("server did not start in time") - return nil, nil + if !serverReady { + tb.Fatal("server did not start in time") + return nil, nil + } + tearDown := func(t testing.TB) { + if err := client.Shutdown(ShutdownReq{Challenge: srv.ShutdownChallenge}); err != nil { + t.Fatal(err) + return + } + if err := client.Ping(PingRequest{Password: HashPassword(salt, TEST_RPC_PASS)}); err == nil { + t.Fatal("server did not shutdown") + return + } + if err := os.RemoveAll(keydbDir); err != nil { + t.Fatal(err) + return + } + } + return client, tearDown } diff --git a/keyserv/rpc_client_test.go b/keyserv/rpc_client_test.go index 2efd172..633fb81 100644 --- a/keyserv/rpc_client_test.go +++ b/keyserv/rpc_client_test.go @@ -3,23 +3,49 @@ package keyserv import ( + "fmt" "github.com/HouzuoGuo/cryptctl/keydb" "github.com/HouzuoGuo/cryptctl/sys" "path" "reflect" "runtime" "strconv" + "strings" "testing" "time" ) +func TestCreateKeyReq_Validate(t *testing.T) { + req := CreateKeyReq{} + if err := req.Validate(); err == nil || !strings.Contains(err.Error(), "UUID must not be empty") { + t.Fatal(err) + } + req.UUID = "/root/../a-" + if err := req.Validate(); err == nil || !strings.Contains(err.Error(), "Illegal chara") { + t.Fatal(err) + } + req.UUID = "abc-def-123-ghi" + if err := req.Validate(); err == nil || !strings.Contains(err.Error(), "Mount point") { + t.Fatal(err) + } + req.MountPoint = "/a" + if err := req.Validate(); err != nil { + t.Fatal(err) + } +} + func TestRPCCalls(t *testing.T) { client, tearDown := StartTestServer(t) defer tearDown(t) - if err := client.Ping(PingRequest{Password: "wrong pass"}); err == nil { + // Retrieve server's password salt + salt, err := client.GetSalt() + if err != nil { + t.Fatal(err) + } + if err := client.Ping(PingRequest{Password: HashPassword(salt, "wrong password")}); err == nil { t.Fatal("did not error") } - if err := client.Ping(PingRequest{Password: TEST_RPC_PASS}); err != nil { + if err := client.Ping(PingRequest{Password: HashPassword(salt, TEST_RPC_PASS)}); err != nil { t.Fatal(err) } // Construct a client via sysconfig @@ -31,55 +57,66 @@ func TestRPCCalls(t *testing.T) { if err != nil { t.Fatal(err) } - if err := scClient.Ping(PingRequest{Password: TEST_RPC_PASS}); err != nil { + if err := scClient.Ping(PingRequest{Password: HashPassword(salt, TEST_RPC_PASS)}); err != nil { t.Fatal(err) } - // Save a bogus key will result in error - err = client.SaveKey(SaveKeyReq{Password: TEST_RPC_PASS, Hostname: "localhost", Record: keydb.Record{}}) + // Refuse to save a key if password is incorrect + createResp, err := client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, "wrong password"), + Hostname: "localhost", + UUID: "aaa", + MountPoint: "/a", + MountOptions: []string{"ro", "noatime"}, + MaxActive: 1, + AliveIntervalSec: 1, + AliveCount: 4, + }) if err == nil { - t.Fatal(err) + t.Fatal("did not error") } - // Save two keys - keyRec1 := keydb.Record{ + // Save two good keys + createResp, err = client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, TEST_RPC_PASS), + Hostname: "localhost", UUID: "aaa", - Key: []byte{0, 1, 2, 3}, MountPoint: "/a", MountOptions: []string{"ro", "noatime"}, MaxActive: 1, AliveIntervalSec: 1, AliveCount: 4, + }) + if err != nil || len(createResp.KeyContent) != KMIPAESKeySizeBits/8 { + t.Fatal(err) } - keyRec2 := keydb.Record{ + createResp, err = client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, TEST_RPC_PASS), + Hostname: "localhost", UUID: "bbb", - Key: []byte{0, 1, 2, 3}, MountPoint: "/b", MountOptions: []string{"ro", "noatime"}, MaxActive: 0, AliveIntervalSec: 1, AliveCount: 4, - } - if err := client.SaveKey(SaveKeyReq{Password: "wrong pass", Hostname: "localhost", Record: keyRec1}); err == nil { - t.Fatal("did not error") - } - if err := client.SaveKey(SaveKeyReq{Password: TEST_RPC_PASS, Hostname: "localhost", Record: keyRec1}); err != nil { - t.Fatal(err) - } - if err := client.SaveKey(SaveKeyReq{Password: TEST_RPC_PASS, Hostname: "localhost", Record: keyRec2}); err != nil { + }) + if err != nil || len(createResp.KeyContent) != KMIPAESKeySizeBits/8 { t.Fatal(err) } - // Retrieve both keys without password - resp, err := client.AutoRetrieveKey(AutoRetrieveKeyReq{ + // Retrieve both keys via automated retrieval without password + autoRetrieveResp, err := client.AutoRetrieveKey(AutoRetrieveKeyReq{ UUIDs: []string{"aaa", "bbb", "does_not_exist"}, Hostname: "localhost", }) if err != nil { t.Fatal(err) } - if len(resp.Granted) != 2 || len(resp.Rejected) != 0 || !reflect.DeepEqual(resp.Missing, []string{"does_not_exist"}) { - t.Fatal(resp.Granted, resp.Rejected, resp.Missing) + if len(autoRetrieveResp.Granted) != 2 || len(autoRetrieveResp.Rejected) != 0 || !reflect.DeepEqual(autoRetrieveResp.Missing, []string{"does_not_exist"}) { + t.Fatal(autoRetrieveResp.Granted, autoRetrieveResp.Rejected, autoRetrieveResp.Missing) + } + if len(autoRetrieveResp.Granted["aaa"].Key) != KMIPAESKeySizeBits/8 || len(autoRetrieveResp.Granted["bbb"].Key) != KMIPAESKeySizeBits/8 { + t.Fatal(autoRetrieveResp.Granted) } verifyKeyA := func(recA keydb.Record) { - if recA.UUID != "aaa" || !reflect.DeepEqual(recA.Key, []byte{0, 1, 2, 3}) || recA.MountPoint != "/a" || + if recA.UUID != "aaa" || recA.MountPoint != "/a" || !reflect.DeepEqual(recA.MountOptions, []string{"ro", "noatime"}) || recA.AliveIntervalSec != 1 || recA.AliveCount != 4 || recA.LastRetrieval.Timestamp == 0 || recA.LastRetrieval.Hostname != "localhost" || recA.LastRetrieval.IP != "127.0.0.1" || len(recA.AliveMessages["127.0.0.1"]) != 1 { @@ -87,41 +124,39 @@ func TestRPCCalls(t *testing.T) { } } verifyKeyB := func(recB keydb.Record) { - if recB.UUID != "bbb" || !reflect.DeepEqual(recB.Key, []byte{0, 1, 2, 3}) || recB.MountPoint != "/b" || + if recB.UUID != "bbb" || recB.MountPoint != "/b" || !reflect.DeepEqual(recB.MountOptions, []string{"ro", "noatime"}) || recB.AliveIntervalSec != 1 || recB.AliveCount != 4 || recB.LastRetrieval.Timestamp == 0 || recB.LastRetrieval.Hostname != "localhost" || recB.LastRetrieval.IP != "127.0.0.1" || len(recB.AliveMessages) != 1 || len(recB.AliveMessages["127.0.0.1"]) != 1 { t.Fatal(recB) } } - // Verify retrieved keys - verifyKeyA(resp.Granted["aaa"]) - verifyKeyB(resp.Granted["bbb"]) + verifyKeyA(autoRetrieveResp.Granted["aaa"]) + verifyKeyB(autoRetrieveResp.Granted["bbb"]) - // Retrieve a key for a second time should be checked against MaxActive allowrance - resp, err = client.AutoRetrieveKey(AutoRetrieveKeyReq{ + // Retrieve a key for a second time should be checked against MaxActive limit + autoRetrieveResp, err = client.AutoRetrieveKey(AutoRetrieveKeyReq{ UUIDs: []string{"aaa", "bbb", "does_not_exist"}, Hostname: "localhost", }) if err != nil { t.Fatal(err) } - if len(resp.Granted) != 1 || !reflect.DeepEqual(resp.Rejected, []string{"aaa"}) || !reflect.DeepEqual(resp.Missing, []string{"does_not_exist"}) { - t.Fatal(resp.Granted, resp.Rejected, resp.Missing) + if len(autoRetrieveResp.Granted) != 1 || !reflect.DeepEqual(autoRetrieveResp.Rejected, []string{"aaa"}) || !reflect.DeepEqual(autoRetrieveResp.Missing, []string{"does_not_exist"}) { + t.Fatal(autoRetrieveResp.Granted, autoRetrieveResp.Rejected, autoRetrieveResp.Missing) } - // Verify retrieved key bbb - verifyKeyB(resp.Granted["bbb"]) + verifyKeyB(autoRetrieveResp.Granted["bbb"]) // Forcibly retrieve both keys and verify if _, err := client.ManualRetrieveKey(ManualRetrieveKeyReq{ - Password: "wrong password", + Password: HashPassword(salt, "wrong password"), UUIDs: []string{"aaa"}, Hostname: "localhost", }); err == nil { t.Fatal("did not error") } manResp, err := client.ManualRetrieveKey(ManualRetrieveKeyReq{ - Password: TEST_RPC_PASS, + Password: HashPassword(salt, TEST_RPC_PASS), UUIDs: []string{"aaa", "bbb", "does_not_exist"}, Hostname: "localhost", }) @@ -152,38 +187,44 @@ func TestRPCCalls(t *testing.T) { // Delete key if err := client.EraseKey(EraseKeyReq{ - Password: "wrongpass", + Password: HashPassword(salt, "wrong password"), Hostname: "localhost", UUID: "aaa", }); err == nil { t.Fatal("did not error") } if err := client.EraseKey(EraseKeyReq{ - Password: TEST_RPC_PASS, + Password: HashPassword(salt, TEST_RPC_PASS), Hostname: "localhost", UUID: "doesnotexist", }); err == nil { t.Fatal("did not error") } if err := client.EraseKey(EraseKeyReq{ - Password: TEST_RPC_PASS, + Password: HashPassword(salt, TEST_RPC_PASS), Hostname: "localhost", UUID: "aaa", }); err != nil { t.Fatal(err) } if err := client.EraseKey(EraseKeyReq{ - Password: TEST_RPC_PASS, + Password: HashPassword(salt, TEST_RPC_PASS), Hostname: "localhost", UUID: "aaa", }); err == nil { t.Fatal("did not error") } + fmt.Println("About to run teardown") } func BenchmarkSaveKey(b *testing.B) { client, tearDown := StartTestServer(b) defer tearDown(b) + // Retrieve server's password salt + salt, err := client.GetSalt() + if err != nil { + b.Fatal(err) + } // Run all transactions in a single goroutine oldMaxprocs := runtime.GOMAXPROCS(-1) defer runtime.GOMAXPROCS(oldMaxprocs) @@ -191,19 +232,15 @@ func BenchmarkSaveKey(b *testing.B) { b.ResetTimer() // The benchmark will run all RPC operations consecutively for i := 0; i < b.N; i++ { - rec := keydb.Record{ + if _, err := client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, TEST_RPC_PASS), + Hostname: "localhost", UUID: "aaa", - Key: []byte{0, 1, 2, 3}, MountPoint: "/a", MountOptions: []string{"ro", "noatime"}, MaxActive: -1, AliveIntervalSec: 1, AliveCount: 4, - } - if err := client.SaveKey(SaveKeyReq{ - Password: TEST_RPC_PASS, - Hostname: "localhost", - Record: rec, }); err != nil { b.Fatal(err) } @@ -214,23 +251,24 @@ func BenchmarkSaveKey(b *testing.B) { func BenchmarkAutoRetrieveKey(b *testing.B) { client, tearDown := StartTestServer(b) defer tearDown(b) + // Retrieve server's password salt + salt, err := client.GetSalt() + if err != nil { + b.Fatal(err) + } // Run all transactions in a single goroutine oldMaxprocs := runtime.GOMAXPROCS(-1) defer runtime.GOMAXPROCS(oldMaxprocs) runtime.GOMAXPROCS(1) - rec := keydb.Record{ + if _, err := client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, TEST_RPC_PASS), + Hostname: "localhost", UUID: "aaa", - Key: []byte{0, 1, 2, 3}, MountPoint: "/a", MountOptions: []string{"ro", "noatime"}, MaxActive: -1, AliveIntervalSec: 1, AliveCount: 4, - } - if err := client.SaveKey(SaveKeyReq{ - Password: TEST_RPC_PASS, - Hostname: "localhost", - Record: rec, }); err != nil { b.Fatal(err) } @@ -250,23 +288,24 @@ func BenchmarkAutoRetrieveKey(b *testing.B) { func BenchmarkManualRetrieveKey(b *testing.B) { client, tearDown := StartTestServer(b) defer tearDown(b) + // Retrieve server's password salt + salt, err := client.GetSalt() + if err != nil { + b.Fatal(err) + } // Run all transactions in a single goroutine oldMaxprocs := runtime.GOMAXPROCS(-1) defer runtime.GOMAXPROCS(oldMaxprocs) runtime.GOMAXPROCS(1) - rec := keydb.Record{ + if _, err := client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, TEST_RPC_PASS), + Hostname: "localhost", UUID: "aaa", - Key: []byte{0, 1, 2, 3}, MountPoint: "/a", MountOptions: []string{"ro", "noatime"}, MaxActive: -1, AliveIntervalSec: 1, AliveCount: 4, - } - if err := client.SaveKey(SaveKeyReq{ - Password: TEST_RPC_PASS, - Hostname: "localhost", - Record: rec, }); err != nil { b.Fatal(err) } @@ -274,7 +313,7 @@ func BenchmarkManualRetrieveKey(b *testing.B) { // The benchmark will run all RPC operations consecutively for i := 0; i < b.N; i++ { if resp, err := client.ManualRetrieveKey(ManualRetrieveKeyReq{ - Password: TEST_RPC_PASS, + Password: HashPassword(salt, TEST_RPC_PASS), UUIDs: []string{"aaa"}, Hostname: "localhost", }); err != nil || len(resp.Granted) != 1 { @@ -287,29 +326,30 @@ func BenchmarkManualRetrieveKey(b *testing.B) { func BenchmarkReportAlive(b *testing.B) { client, tearDown := StartTestServer(b) defer tearDown(b) + // Retrieve server's password salt + salt, err := client.GetSalt() + if err != nil { + b.Fatal(err) + } // Run all benchmark operations in a single goroutine to know the real performance oldMaxprocs := runtime.GOMAXPROCS(-1) defer runtime.GOMAXPROCS(oldMaxprocs) runtime.GOMAXPROCS(1) - rec := keydb.Record{ + if _, err := client.CreateKey(CreateKeyReq{ + Password: HashPassword(salt, TEST_RPC_PASS), + Hostname: "localhost", UUID: "aaa", - Key: []byte{0, 1, 2, 3}, MountPoint: "/a", MountOptions: []string{"ro", "noatime"}, MaxActive: -1, AliveIntervalSec: 1, AliveCount: 4, - } - if err := client.SaveKey(SaveKeyReq{ - Password: TEST_RPC_PASS, - Hostname: "localhost", - Record: rec, }); err != nil { b.Fatal(err) } // Retrieve the key so that this computer becomes eligible to send alive messages if resp, err := client.ManualRetrieveKey(ManualRetrieveKeyReq{ - Password: TEST_RPC_PASS, + Password: HashPassword(salt, TEST_RPC_PASS), UUIDs: []string{"aaa"}, Hostname: "localhost", }); err != nil || len(resp.Granted) != 1 { diff --git a/keyserv/rpc_svc.go b/keyserv/rpc_svc.go index fecdb90..c445d28 100644 --- a/keyserv/rpc_svc.go +++ b/keyserv/rpc_svc.go @@ -7,17 +7,20 @@ import ( "crypto/sha512" "crypto/subtle" "crypto/tls" + "crypto/x509" "encoding/hex" "errors" "fmt" "github.com/HouzuoGuo/cryptctl/fs" "github.com/HouzuoGuo/cryptctl/keydb" "github.com/HouzuoGuo/cryptctl/sys" + "io/ioutil" "log" "net" "net/rpc" "os" "path" + "path/filepath" "reflect" "strings" "time" @@ -30,8 +33,10 @@ const ( SRV_CONF_PASS_HASH = "AUTH_PASSWORD_HASH" SRV_CONF_PASS_SALT = "AUTH_PASSWORD_SALT" + SRV_CONF_TLS_CA = "TLS_CA_PEM" SRV_CONF_TLS_CERT = "TLS_CERT_PEM" SRV_CONF_TLS_KEY = "TLS_CERT_KEY_PEM" + SRV_CONF_TLS_VALIDATE_CLIENT = "TLS_VALIDATE_CLIENT" SRV_CONF_LISTEN_ADDR = "LISTEN_ADDRESS" SRV_CONF_LISTEN_PORT = "LISTEN_PORT" SRV_CONF_KEYDB_DIR = "KEY_DB_DIR" @@ -40,10 +45,13 @@ const ( SRV_CONF_MAIL_RETRIEVAL_SUBJ = "EMAIL_KEY_RETRIEVAL_SUBJECT" SRV_CONF_MAIL_RETRIEVAL_TEXT = "EMAIL_KEY_RETRIEVAL_GREETING" - SRV_CONF_KMIP_SERVER_HOST = "KMIP_SERVER_HOST" - SRV_CONF_KMIP_SERVER_PORT = "KMIP_SERVER_PORT" - SRV_CONF_KMIP_SERVER_USER = "KMIP_SERVER_USER" - SRV_CONF_KMIP_SERVER_PASS = "KMIP_SERVER_PASS" + SRV_CONF_KMIP_SERVER_HOST = "KMIP_SERVER_HOST" + SRV_CONF_KMIP_SERVER_PORT = "KMIP_SERVER_PORT" + SRV_CONF_KMIP_SERVER_USER = "KMIP_SERVER_USER" + SRV_CONF_KMIP_SERVER_PASS = "KMIP_SERVER_PASS" + SRV_CONF_KMIP_SERVER_TLS_CA = "KMIP_CA_PEM" + SRV_CONF_KMIP_SERVER_TLS_CERT = "KMIP_TLS_CERT_PEM" + SRV_CONF_KMIP_SERVER_TLS_KEY = "KMIP_TLS_CERT_KEY_PEM" ) var PkgInGopath = path.Join(path.Join(os.Getenv("GOPATH"), "/src/github.com/HouzuoGuo/cryptctl")) // this package in gopath @@ -57,15 +65,18 @@ func GetDefaultKeySvcConf() *sys.Sysconfig { } // Return a newly generated salt for hasing passwords. -func NewSalt() (ret [LEN_PASS_SALT]byte) { +func NewSalt() (ret PasswordSalt) { if _, err := rand.Read(ret[:]); err != nil { panic(fmt.Errorf("NewSalt: failed to read from random source - %v", err)) } return } +type PasswordSalt [LEN_PASS_SALT]byte +type HashedPassword [sha512.Size]byte + // Compute a salted password hash using SHA512 method. -func HashPassword(salt [LEN_PASS_SALT]byte, plainText string) [sha512.Size]byte { +func HashPassword(salt PasswordSalt, plainText string) HashedPassword { plainBytes := []byte(plainText) // saltedBytes = salt + plainBytes saltedBytes := make([]byte, LEN_PASS_SALT+len(plainBytes)) @@ -79,6 +90,8 @@ func HashPassword(salt [LEN_PASS_SALT]byte, plainText string) [sha512.Size]byte type CryptServiceConfig struct { PasswordHash [sha512.Size]byte // password hash (salted) that authenticates incoming requests PasswordSalt [LEN_PASS_SALT]byte // password hash salt + CertAuthorityPEM string // path to PEM-encoded CA certificate + ValidateClientCert bool // whether the server will authenticate its client before accepting RPC request CertPEM string // path to PEM-encoded TLS certificate KeyPEM string // path to PEM-encoded TLS certificate key Address string // address of the network interface to listen on @@ -92,6 +105,9 @@ type CryptServiceConfig struct { KMIPPort int // optional KMIP server port KMIPUser string // optional KMIP service access user KMIPPass string // optional KMIP service access password + KMIPCertAuthorityPEM string // optional KMIP server CA certificate + KMIPCertPEM string // optional KMIP client certificate + KMIPKeyPEM string // optional KMIP client certificate key } // Preliminarily validate configuration and report error. @@ -123,6 +139,8 @@ func (conf *CryptServiceConfig) ReadFromSysconfig(sysconf *sys.Sysconfig) error copy(conf.PasswordHash[:], passwordHash) copy(conf.PasswordSalt[:], passwordSalt) + conf.CertAuthorityPEM = sysconf.GetString(SRV_CONF_TLS_CA, "") + conf.ValidateClientCert = sysconf.GetBool(SRV_CONF_TLS_VALIDATE_CLIENT, false) conf.CertPEM = sysconf.GetString(SRV_CONF_TLS_CERT, "") conf.KeyPEM = sysconf.GetString(SRV_CONF_TLS_KEY, "") conf.Address = sysconf.GetString(SRV_CONF_LISTEN_ADDR, "0.0.0.0") @@ -139,6 +157,9 @@ func (conf *CryptServiceConfig) ReadFromSysconfig(sysconf *sys.Sysconfig) error conf.KMIPPort = sysconf.GetInt(SRV_CONF_KMIP_SERVER_PORT, 0) conf.KMIPUser = sysconf.GetString(SRV_CONF_KMIP_SERVER_USER, "") conf.KMIPPass = sysconf.GetString(SRV_CONF_KMIP_SERVER_PASS, "") + conf.KMIPCertAuthorityPEM = sysconf.GetString(SRV_CONF_KMIP_SERVER_TLS_CA, "") + conf.KMIPCertPEM = sysconf.GetString(SRV_CONF_KMIP_SERVER_TLS_CERT, "") + conf.KMIPKeyPEM = sysconf.GetString(SRV_CONF_KMIP_SERVER_TLS_KEY, "") return conf.Validate() } @@ -149,13 +170,15 @@ type CryptServer struct { KeyDB *keydb.DB // encryption key database TLSConfig *tls.Config // TLS certificate chain and private key Listener net.Listener // listener for client connections + BuiltInKMIPServer *KMIPServer // Built-in KMIP server in case there's no external server + KMIPClient *KMIPClient // KMIP client connected to either built-in KMIP server or external server ShutdownChallenge []byte // a random secret that must be verified for incoming shutdown requests } // Initialise an RPC server from sysconfig file text. func NewCryptServer(config CryptServiceConfig, mailer Mailer) (srv *CryptServer, err error) { if err = config.Validate(); err != nil { - return + return nil, err } srv = &CryptServer{ Config: config, @@ -164,10 +187,22 @@ func NewCryptServer(config CryptServiceConfig, mailer Mailer) (srv *CryptServer, } srv.KeyDB, err = keydb.OpenDB(config.KeyDBDir) if err != nil { - return + return nil, err } srv.TLSConfig.Certificates = make([]tls.Certificate, 1) srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(config.CertPEM, config.KeyPEM) + // Configure client authentication upon request + if config.ValidateClientCert { + caPEM, err := ioutil.ReadFile(config.CertAuthorityPEM) + if err != nil { + return nil, err + } + caPool := x509.NewCertPool() + caPool.AppendCertsFromPEM(caPEM) + srv.TLSConfig.ClientCAs = caPool + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + srv.TLSConfig.BuildNameToCertificate() // Shutdown challenge is an array of random bytes srv.ShutdownChallenge = make([]byte, LEN_SHUTDOWN_CHALLENGE) if _, err = rand.Read(srv.ShutdownChallenge); err != nil { @@ -176,19 +211,61 @@ func NewCryptServer(config CryptServiceConfig, mailer Mailer) (srv *CryptServer, return } -// Start RPC server and block until the server listener is told to shut down. -func (srv *CryptServer) ListenRPC() (err error) { - // It is not necessary to validate initial server setup, because ping should always work. - srv.Listener, err = tls.Listen("tcp", fmt.Sprintf("%s:%d", srv.Config.Address, srv.Config.Port), srv.TLSConfig) - if err != nil { - return fmt.Errorf("ListenRPC: failed to listen on %s:%d - %v", srv.Config.Address, srv.Config.Port, err) +/* +Start RPC server. If the RPC server does not have KMIP connectivity settings, start an incomplete implementation +of KMIP server. +Block caller until the listener quits. +*/ +func (srv *CryptServer) ListenRPC() error { + var err error + if srv.Config.KMIPHost == "" { + // If RPC server settings do not have KMIP connectivity settings, start my own KMIP server. + if srv.BuiltInKMIPServer, err = NewKMIPServer(srv.KeyDB, srv.Config.CertPEM, srv.Config.KeyPEM); err != nil { + return err + } + if err := srv.BuiltInKMIPServer.Listen(); err != nil { + return err + } + go srv.BuiltInKMIPServer.HandleConnections() + // The client initialisation routine does not immediately connect to server. + if srv.KMIPClient, err = NewKMIPClient( + "localhost", srv.BuiltInKMIPServer.GetPort(), + "does-not-matter", string(srv.BuiltInKMIPServer.PasswordChallenge), + nil, "", ""); err != nil { + return err + } + srv.KMIPClient.TLSConfig.InsecureSkipVerify = true + } else { + // No need to start built-in KMIP server, so only initialise the client. + var caCert []byte + if srv.Config.KMIPCertAuthorityPEM != "" { + caCert, err = ioutil.ReadFile(srv.Config.KMIPCertAuthorityPEM) + if err != nil { + return err + } + } + if srv.KMIPClient, err = NewKMIPClient( + srv.Config.KMIPHost, srv.Config.KMIPPort, + srv.Config.KMIPUser, srv.Config.KMIPPass, + caCert, srv.Config.KMIPCertPEM, srv.Config.KMIPKeyPEM); err != nil { + return err + } } - log.Printf("ListenRPC: listening on %s:%d using TLS certficate \"%s\"", srv.Config.Address, srv.Config.Port, srv.Config.CertPEM) + // Start ordinary RPC server + if srv.Listener, err = tls.Listen("tcp", fmt.Sprintf("%s:%d", srv.Config.Address, srv.Config.Port), srv.TLSConfig); err != nil { + return fmt.Errorf("CryptServer.ListenRPC: failed to listen on %s:%d - %v", srv.Config.Address, srv.Config.Port, err) + } + log.Printf("CryptServer.ListenRPC: listening on %s:%d using TLS certficate \"%s\"", srv.Config.Address, srv.Config.Port, srv.Config.CertPEM) + return nil +} + +// Handle incoming connections in a loop. Block caller until listener closes. +func (srv *CryptServer) HandleConnections() { for { incoming, err := srv.Listener.Accept() if err != nil { - log.Printf("ListenRPC: quit now - %v", err) - return nil + log.Printf("CryptServer.HandleConnections: quit now - %v", err) + return } // The connection is served by a dedicated RPC server instance go func(conn net.Conn) { @@ -196,7 +273,16 @@ func (srv *CryptServer) ListenRPC() (err error) { conn.Close() }(incoming) } - return nil +} + +// Shut down RPC server listener. If built-in KMIP server was started, shut that one down as well. +func (srv *CryptServer) Shutdown() { + if listener := srv.Listener.Close(); listener != nil { + srv.Listener.Close() + } + if kmipServer := srv.BuiltInKMIPServer; kmipServer != nil { + kmipServer.Shutdown() + } } /* @@ -225,13 +311,12 @@ func (srv *CryptServer) CheckInitialSetup() error { } // Validate a password against stored hash. -func (srv *CryptServer) ValidatePassword(plainText string) error { +func (srv *CryptServer) ValidatePassword(pass HashedPassword) error { // Fail straight away if server setup is missing if err := srv.CheckInitialSetup(); err != nil { return err } - hashInput := HashPassword(srv.Config.PasswordSalt, plainText) - if subtle.ConstantTimeCompare(hashInput[:], srv.Config.PasswordHash[:]) != 1 { + if subtle.ConstantTimeCompare(pass[:], srv.Config.PasswordHash[:]) != 1 { return errors.New("ValidatePassword: password is incorrect") } return nil @@ -258,7 +343,7 @@ var RPCObjNameFmt = reflect.TypeOf(CryptServiceConn{}).Name() + ".%s" // for con // A request to ping server and test its readiness for key operations. type PingRequest struct { - Password string // access is only granted after correct password is given + Password HashedPassword // access is only granted after correct password is given } // If the server is ready to manage encryption keys, return nothing successfully. Return an error if otherwise. @@ -274,39 +359,84 @@ func (rpcConn *CryptServiceConn) Ping(req PingRequest, _ *DummyAttr) error { type DummyAttr bool // dummy type for a placeholder receiver in an RPC function -// A request to upload an encryption key to server. -type SaveKeyReq struct { - Password string // access is granted only after the correct password is given - Hostname string // client's host name (for logging only) - Record keydb.Record // the new key record +// A request to create an encryption key on server. +type CreateKeyReq struct { + Password HashedPassword // access is granted only after the correct password is given + Hostname string // computer host name (for logging only) + UUID string // file system uuid + MountPoint string // mount point of the file system + MountOptions []string // mount options of the file system + MaxActive int // maximum allowed active key users (computers), set to <=0 to allow unlimited. + AliveIntervalSec int //interval in seconds at which all user of the file system holding this key must report they're online + AliveCount int //a computer holding the file system is considered offline after missing so many alive messages +} + +// Make sure that the request attributes are sane. +func (req CreateKeyReq) Validate() error { + if req.UUID == "" { + return errors.New("UUID must not be empty") + } else if cleanedID := filepath.Clean(req.UUID); cleanedID != req.UUID { + return errors.New("Illegal characters appeared in UUID") + } else if req.MountPoint == "" { + return errors.New("Mount point must not be empty") + } + return nil +} + +// A response to a newly saved key +type CreateKeyResp struct { + KeyContent []byte // Disk encryption key } // Save a new key record. -func (rpcConn *CryptServiceConn) SaveKey(req SaveKeyReq, _ *DummyAttr) error { +func (rpcConn *CryptServiceConn) CreateKey(req CreateKeyReq, resp *CreateKeyResp) error { if err := rpcConn.Svc.ValidatePassword(req.Password); err != nil { return err - } - // Input record may not contain empty attributes - req.Record.FillBlanks() - // The requester is considered to be the last host to have "retrieved" the key - req.Record.LastRetrieval = keydb.AliveMessage{ - Hostname: req.Hostname, - IP: rpcConn.RemoteAddr, - Timestamp: time.Now().Unix(), - } - // Input record must be validated before saving - if err := req.Record.Validate(); err != nil { + } else if err := req.Validate(); err != nil { return err } - // TODO: insert KMIP logic here - if _, err := rpcConn.Svc.KeyDB.Upsert(req.Record); err != nil { + // No matter key is located in built-in KMIP server or external KMIP server, the KMIP client needs to create the key. + kmipKeyID, err := rpcConn.Svc.KMIPClient.CreateKey(req.UUID) + if err != nil { + return fmt.Errorf("CryptServiceConn.CreateKey: KMIP client refused to create the key - %v", err) + } + // Complete key tracking record in my database + var keyRecord keydb.Record + if rpcConn.Svc.BuiltInKMIPServer != nil { + // Retrieve the incomplete key record saved by built-in KMIP server + var found bool + keyRecord, found = rpcConn.Svc.KeyDB.GetByID(kmipKeyID) + if !found { + return fmt.Errorf("CryptServiceConn.CreateKey: new key ID \"%s\" just disappeared from database", kmipKeyID) + } + } + /* + If the record was created by built-in KMIP server, some record details are already in-place. + But if external KMIP server was used, then the key record does not yet even exist in my database, hence complete all + record details no matter what. + */ + keyRecord.ID = kmipKeyID + keyRecord.Version = keydb.CurrentRecordVersion + keyRecord.CreationTime = time.Now() + keyRecord.UUID = req.UUID + keyRecord.MountPoint = req.MountPoint + keyRecord.MountOptions = req.MountOptions + keyRecord.MaxActive = req.MaxActive + keyRecord.AliveIntervalSec = req.AliveIntervalSec + keyRecord.AliveCount = req.AliveCount + if _, err := rpcConn.Svc.KeyDB.Upsert(keyRecord); err != nil { + return fmt.Errorf("CryptServiceConn.CreateKey: failed to save key tracking record into database - %v", err) + } + // Ask server for the actual encryption key to formulate RPC response + resp.KeyContent, err = rpcConn.askForKeyContent(kmipKeyID) + if err != nil { return err } // Format a record for journal - journalRec := req.Record + journalRec := keyRecord journalRec.Key = nil - // Always log to system journal - log.Printf(`SaveKey: %s (%s) has saved new key %s`, + // Always log the event to system journal + log.Printf(`CryptServiceConn.CreateKey: %s (%s) has saved new key %s`, rpcConn.RemoteAddr, req.Hostname, journalRec.FormatAttrs(" ")) // Send optional notification email in background if rpcConn.Svc.Mailer.ValidateConfig() == nil { @@ -316,11 +446,12 @@ func (rpcConn *CryptServiceConn) SaveKey(req SaveKeyReq, _ *DummyAttr) error { rpcConn.RemoteAddr, req.Hostname, journalRec.MountPoint) text := fmt.Sprintf("%s\r\n\r\n%s", rpcConn.Svc.Config.KeyCreationGreeting, journalRec.FormatAttrs("\r\n")) if err := rpcConn.Svc.Mailer.Send(subject, text); err != nil { - log.Printf("SaveKey: failed to send email notification after saving %s (%s)'s key of %s - %v", - rpcConn.RemoteAddr, req.Hostname, req.Record.MountPoint, err) + log.Printf("CryptServiceConn.CreateKey: failed to send email notification after saving %s (%s)'s key of %s - %v", + rpcConn.RemoteAddr, req.Hostname, journalRec.MountPoint, err) } }() } + return nil } @@ -332,11 +463,11 @@ func (rpcConn *CryptServiceConn) logRetrieval(uuids []string, hostname string, g retrievedUUIDs = append(retrievedUUIDs, uuid) } if len(granted) > 0 { - log.Printf(`RetrieveKey: %s (%s) has been granted keys of: %s`, + log.Printf(`CryptServiceConn.logRetrieval: %s (%s) has been granted keys of: %s`, rpcConn.RemoteAddr, hostname, strings.Join(retrievedUUIDs, " ")) } if len(rejected) > 0 { - log.Printf(`RetrieveKey: %s (%s) has been rejected keys of: %s`, + log.Printf(`CryptServiceConn.logRetrieval: %s (%s) has been rejected keys of: %s`, rpcConn.RemoteAddr, hostname, strings.Join(rejected, " ")) } // There is really no need to log the missing keys @@ -350,7 +481,7 @@ func (rpcConn *CryptServiceConn) logRetrieval(uuids []string, hostname string, g text += fmt.Sprintf("%s - %s\r\n", uuid, record.MountPoint) } if err := rpcConn.Svc.Mailer.Send(subject, text); err != nil { - log.Printf("RetrieveKey: failed to send email notification after granting keys to %s (%s) - %v", + log.Printf("CryptServiceConn.logRetrieval: failed to send email notification after granting keys to %s (%s) - %v", rpcConn.RemoteAddr, hostname, err) } }(granted) @@ -370,6 +501,18 @@ type AutoRetrieveKeyResp struct { Missing []string // these keys cannot be found in database } +// Retrieve key content by KMIP record ID. Return key content. +func (rpcConn *CryptServiceConn) askForKeyContent(kmipID string) (key []byte, err error) { + key, err = rpcConn.Svc.KMIPClient.GetKey(kmipID) + if err != nil { + // This is severe enough to deserve a server side log message + msg := fmt.Sprintf("CryptServiceConn.askForKeyContent: KMIP client failed to answer to key request - %v", err) + log.Print(msg) + return nil, errors.New(msg) + } + return +} + // Retrieve encryption keys without using a password. The request is usually sent automatically when disk comes online. func (rpcConn *CryptServiceConn) AutoRetrieveKey(req AutoRetrieveKeyReq, resp *AutoRetrieveKeyResp) error { // Retrieve the keys and write down who retrieved it @@ -379,15 +522,24 @@ func (rpcConn *CryptServiceConn) AutoRetrieveKey(req AutoRetrieveKeyReq, resp *A Timestamp: time.Now().Unix(), } resp.Granted, resp.Rejected, resp.Missing = rpcConn.Svc.KeyDB.Select(requester, true, req.UUIDs...) + // Key content of granted records are stored in KMIP + for uuid, grantedRecord := range resp.Granted { + key, err := rpcConn.askForKeyContent(grantedRecord.ID) + if err != nil { + return err + } + grantedRecord.Key = key + resp.Granted[uuid] = grantedRecord + } rpcConn.logRetrieval(req.UUIDs, req.Hostname, resp.Granted, resp.Rejected, resp.Missing) return nil } // A request to forcibly retrieve encryption keys using a password. type ManualRetrieveKeyReq struct { - Password string // access to keys is granted only after the correct password is given. - UUIDs []string // (locked) file system UUIDs - Hostname string // client's host name (for logging only) + Password HashedPassword // access to keys is granted only after the correct password is given. + UUIDs []string // (locked) file system UUIDs + Hostname string // client's host name (for logging only) } // A response to forced key retrieval (with password) request. @@ -408,6 +560,15 @@ func (rpcConn *CryptServiceConn) ManualRetrieveKey(req ManualRetrieveKeyReq, res Timestamp: time.Now().Unix(), } resp.Granted, _, resp.Missing = rpcConn.Svc.KeyDB.Select(requester, false, req.UUIDs...) + // Key content of granted records are stored in KMIP + for uuid, grantedRecord := range resp.Granted { + key, err := rpcConn.askForKeyContent(grantedRecord.ID) + if err != nil { + return err + } + grantedRecord.Key = key + resp.Granted[uuid] = grantedRecord + } rpcConn.logRetrieval(req.UUIDs, req.Hostname, resp.Granted, []string{}, resp.Missing) return nil } @@ -435,9 +596,9 @@ func (rpcConn *CryptServiceConn) ReportAlive(req ReportAliveReq, rejectedUUIDs * // A request to erase an encryption key. type EraseKeyReq struct { - Password string // access is granted only after the correct password is given - Hostname string // client's host name (for logging only) - UUID string // UUID of the disk to delete key for + Password HashedPassword // access is granted only after the correct password is given + Hostname string // client's host name (for logging only) + UUID string // UUID of the disk to delete key for } func (rpcConn *CryptServiceConn) EraseKey(req EraseKeyReq, _ *DummyAttr) error { @@ -466,8 +627,14 @@ func (rpcConn *CryptServiceConn) Shutdown(req ShutdownReq, _ *DummyAttr) error { return errors.New("Shutdown: remote IP is not 127.0.0.1") } if subtle.ConstantTimeCompare(rpcConn.Svc.ShutdownChallenge, req.Challenge) != 1 { - return errors.New("Shutdown: incorrect pass") + return errors.New("Shutdown: incorrect challenge") } err := rpcConn.Svc.Listener.Close() return err } + +// Hand over the salt that was used to hash server's access password. +func (rpcConn *CryptServiceConn) GetSalt(_ DummyAttr, salt *PasswordSalt) error { + copy((*salt)[:], rpcConn.Svc.Config.PasswordSalt[:]) + return nil +} diff --git a/kmip/structure/serialisation_test.go b/kmip/structure/serialisation_test.go index 6c6a40a..3eca8c3 100644 --- a/kmip/structure/serialisation_test.go +++ b/kmip/structure/serialisation_test.go @@ -15,7 +15,7 @@ func TestSerialiseSimpleStruct(t *testing.T) { Tag: TagCredential, Typ: ttlv.TypStruct, }, - Items: []interface{}{ + Items: []ttlv.Item{ &ttlv.Enumeration{ TTL: ttlv.TTL{ Tag: TagCredentialType, @@ -28,7 +28,7 @@ func TestSerialiseSimpleStruct(t *testing.T) { Tag: TagCredentialValue, Typ: ttlv.TypStruct, }, - Items: []interface{}{ + Items: []ttlv.Item{ &ttlv.Text{ TTL: ttlv.TTL{ Tag: TagUsername, diff --git a/kmip/ttlv/dencode.go b/kmip/ttlv/dencode.go index 3d887cd..3bbda70 100644 --- a/kmip/ttlv/dencode.go +++ b/kmip/ttlv/dencode.go @@ -87,17 +87,20 @@ func EncodeIntBigEndian(someInt interface{}) []byte { return buf.Bytes() } -// Encode any TTLV item. Input must be pointer to item. -func EncodeAny(thing interface{}) (ret []byte) { +// Encode any TTLV item. Input must be pointer to item and must not be nil. +func EncodeAny(thing Item) (ret []byte) { buf := new(bytes.Buffer) + // Tolerate constructed TTLV items that did not carry a type byte switch t := thing.(type) { case *Structure: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) for _, item := range t.Items { buf.Write(EncodeAny(item)) } case *Integer: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) // Integer has length of 4 @@ -105,11 +108,13 @@ func EncodeAny(thing interface{}) (ret []byte) { // An additional 4 bytes of padding not counted against length buf.Write([]byte{0, 0, 0, 0}) case *LongInteger: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) // LongInteger has length of 8 buf.Write(EncodeIntBigEndian(t.Value)) case *Enumeration: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) // Enumeration has length of 4 @@ -117,11 +122,13 @@ func EncodeAny(thing interface{}) (ret []byte) { // An additional 4 bytes of padding not counted against length buf.Write([]byte{0, 0, 0, 0}) case *DateTime: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) // DateTime has length of 8 buf.Write(EncodeIntBigEndian(t.Time.Unix())) case *Text: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) buf.Write([]byte(t.Value)) @@ -129,12 +136,15 @@ func EncodeAny(thing interface{}) (ret []byte) { padding := make([]byte, RoundUpTo8(len(t.Value))-len(t.Value)) buf.Write(padding) case *Bytes: + t.ResetTyp() t.TTL.WriteTTTo(buf) buf.Write(EncodeIntBigEndian(int32(t.GetLength()))) buf.Write(t.Value) // Pad with zero bytes to line up with 8 padding := make([]byte, RoundUpTo8(len(t.Value))-len(t.Value)) buf.Write(padding) + default: + log.Panicf("EncodeAny: input is nil or type \"%s\"'s encoder is not implemented", reflect.TypeOf(thing).String()) } return buf.Bytes() } @@ -156,6 +166,8 @@ func DecodeAny(in []byte) (ret Item, length int, err error) { tag, typ, length32, err := DecodeTTL(in) length = int(length32) if err == io.EOF { + // The condition of reaching end of buffer is not an error + err = nil return } else if err != nil { return @@ -191,7 +203,7 @@ func DecodeAny(in []byte) (ret Item, length int, err error) { ret = long case TypStruct: in = in[:length] - structure := &Structure{TTL: common, Items: make([]interface{}, 0, 4)} + structure := &Structure{TTL: common, Items: make([]Item, 0, 4)} itemIndex := 0 for { // Decode item at current index @@ -229,6 +241,8 @@ func DecodeAny(in []byte) (ret Item, length int, err error) { default: return nil, length, fmt.Errorf("DecodeAny: does not know how to decode %s's type", common.TTLString()) } + // Type byte was not directly decoded from input buffer by the switch structure above, hence it is set here. + ret.ResetTyp() return } diff --git a/kmip/ttlv/types.go b/kmip/ttlv/types.go index 01bec1a..d41d0aa 100644 --- a/kmip/ttlv/types.go +++ b/kmip/ttlv/types.go @@ -53,7 +53,7 @@ func (com TTL) WriteTTTo(out *bytes.Buffer) { // TTLV structure. Length of value is sum of item lengths including padding. type Structure struct { TTL - Items []interface{} + Items []Item } func (st *Structure) GetTTL() TTL { @@ -66,7 +66,7 @@ func (st *Structure) GetLength() int { // Structure length counts individual item's TTL newLen += LenTTL // Item value length does not include padding - itemLen := item.(Item).GetLength() + itemLen := item.GetLength() // But structure length counts padding newLen += RoundUpTo8(itemLen) } @@ -80,9 +80,8 @@ func (st *Structure) ResetTyp() { // Construct a new structure with the specified tag, place the items inside the structure as well. Each item must be a pointer to Item. func NewStructure(tag Tag, items ...Item) *Structure { - ret := &Structure{TTL: TTL{Tag: tag, Typ: TypStruct}, Items: make([]interface{}, 0, 8)} + ret := &Structure{TTL: TTL{Tag: tag, Typ: TypStruct}, Items: make([]Item, 0, 8)} for _, item := range items { - item.ResetTyp() ret.Items = append(ret.Items, item) } return ret diff --git a/ospackage/etc/sysconfig/cryptctl-client b/ospackage/etc/sysconfig/cryptctl-client index deeb0e7..dd7f270 100644 --- a/ospackage/etc/sysconfig/cryptctl-client +++ b/ospackage/etc/sysconfig/cryptctl-client @@ -20,3 +20,15 @@ KEY_SERVER_PORT=3737 # (Optional) path to PEM-encoded custom certificate authority that issued the TLS certificate for the key server. # Leave empty if the TLS certificate was issued by a well-known certificate authority. TLS_CA_PEM="" + +## Type: string +## Default: "" +# +# (Optional) Location of PEM-encoded TLS certificate file to identify the client to server. +TLS_CERT_PEM="" + +## Type: string +## Default: "" +# +# (Optional) Location of PEM-encoded TLS certificate key file to identify the client to server. +TLS_CERT_KEY_PEM="" diff --git a/ospackage/etc/sysconfig/cryptctl-server b/ospackage/etc/sysconfig/cryptctl-server index 8643a31..bf6458a 100644 --- a/ospackage/etc/sysconfig/cryptctl-server +++ b/ospackage/etc/sysconfig/cryptctl-server @@ -16,6 +16,13 @@ AUTH_PASSWORD_HASH="" # initial setup routine of cryptctl server, hence avoid editing this parameter manually. AUTH_PASSWORD_SALT="" +## Type: string +## Default: "" +# +# (Optional) path to PEM-encoded custom certificate authority that issued the TLS certificate for the key server. +# Leave empty if the TLS certificate was issued by a well-known certificate authority. +TLS_CA_PEM="" + ## Type: string ## Default: "" # @@ -28,6 +35,12 @@ TLS_CERT_PEM="" # Location of PEM-encoded TLS certificate key file that corresponds to the certificate, this is mandatory. TLS_CERT_KEY_PEM="" +## Type: string +## Default: "no" +# +# Whether the server will validate client's certificate before accepting its request. +TLS_VALIDATE_CLIENT="no" + ## Type: string ## Default: "0.0.0.0" # @@ -95,23 +108,42 @@ EMAIL_KEY_RETRIEVAL_GREETING="The key server has given out the following encrypt ## Type: string ## Default: "" # -# Host name or IP of KMIP server, if the key server should act as an KMIP client. +# If key server should act as KMIP proxy, this is the KMIP host name. KMIP_SERVER_HOST="" ## Type: integer ## Default: "" # -# Port number of KMIP service, if the key server should act as an KMIP client. +# If key server should act as KMIP proxy, this is the KMIP service port. KMIP_SERVER_PORT="" ## Type: string ## Default: "" # -# User name for KMIP access, if the key server should act as an KMIP client. +# If key server should act as KMIP proxy, this is the KMIP access user name. KMIP_SERVER_USER="" ## Type: string ## Default: "" # -# Password for KMIP access, if the key server should act as an KMIP client. -KMIP_SERVER_PASS="" \ No newline at end of file +# If key server should act as KMIP proxy, this is the KMIP access password. +KMIP_SERVER_PASS="" + +## Type: string +## Default: "" +# +# If key server should act as KMIP proxy, this is the KMIP server CA certificate. +KMIP_CA_PEM="" + +## Type: string +## Default: "" +# +# If key server should act as KMIP proxy, this is the KMIP client certificate. +KMIP_TLS_CERT_PEM="" + +## Type: string +## Default: "" +# +# If key server should act as KMIP proxy, this is the KMIP client certificate key. +KMIP_TLS_CERT_KEY_PEM="" + diff --git a/ospackage/man/cryptctl.8 b/ospackage/man/cryptctl.8 index f02ded3..99a3d5e 100644 --- a/ospackage/man/cryptctl.8 +++ b/ospackage/man/cryptctl.8 @@ -58,6 +58,11 @@ key server closely tracks its IP, host name, and timestamp, in order to determin the key; if the upper limit number of computers is reached, the key will no longer be handed out automatically; system administrator can always retrieve encryption keys by using key server's access password. +.I cryptctl +can utilise an external key management appliance that understands KMIP v1.3 to store the actual encryption keys. The server +setup sequence asks the system administrator for KMIP connectivity details such as host name, port, certificate, and user +credentials. Of course cryptctl functions perfectly well without using an external key management appliance. + .SH KEY SERVER ACTIONS .SS .TP @@ -83,7 +88,7 @@ the input parameters, and after a final confirmation prompt, encryption routine .IP \n[step] Wipe the partition to be encrypted. .IP \n+[step] -Generate a 512-bit encryption key from cryptographically secure random source. +Generate a 512-bit encryption key from cryptographically secure random source provided by internal (auto-launched) or external KMIP service. .IP \n+[step] Set up LUKS metadata on the partition to encrypt, then make new file system on it, matching the type of that from the directory to encrypt. @@ -154,6 +159,9 @@ to every operation: If you wish to maintain a certificate infrastructure for multiple key servers in a production environment, the YaST Certificate Management program may come in handy. +The server can optionally enforce verification of clients' certificate, should you decide to let server verify clients' +identity before serving them encryption keys. The behaviour is setup during server's and client's setup sequence. + .SH CHANGE/REVOKE OR DELETE ENCRYPTION KEY If you decide to revoke or change encryption key for an encrypted file system, please back up the encrypted data onto a disk and re-run the encryption routine in order to encrypt with a new key. The utility does not provide other means to diff --git a/routine/encrypt.go b/routine/encrypt.go index c48e0c7..47941b9 100644 --- a/routine/encrypt.go +++ b/routine/encrypt.go @@ -7,10 +7,10 @@ import ( "errors" "fmt" "github.com/HouzuoGuo/cryptctl/fs" - "github.com/HouzuoGuo/cryptctl/keydb" - "github.com/HouzuoGuo/cryptctl/keyrpc" + "github.com/HouzuoGuo/cryptctl/keyserv" "github.com/HouzuoGuo/cryptctl/sys" "io" + "log" "os" "path" "path/filepath" @@ -39,10 +39,22 @@ const ( MSG_E_MKDIR = "Failed to make directory \"%s\" - %v" MSG_E_RENAME_DIR = "Failed to rename directory \"%s\" into \"%s\" - %v" MSG_E_NO_DEV_INFO = "Failed to retrieve block device information of \"%s\"" - MSG_E_RPC_KEY_SAVE = "Failed to upload key record: %v" + MSG_E_RPC_KEY_CREATE = "Failed to create an encryption key: %v" MSG_OK_CONGRATS = "\nCongratulations! Data in \"%s\" is now safely encrypted in \"%s\".\nRemember to manually delete the original un-encrypted copy in \"%s\".\n" ) +// Create a new UUID. +func MakeUUID() string { + buf := make([]byte, 16) + _, err := rand.Read(buf) + if err != nil { + log.Panicf("MakeUUID: random source ran dry - %v", err) + } + buf[8] = (buf[8] | 0x80) & 0xBF + buf[6] = (buf[6] | 0x40) & 0x4F + return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:]) +} + // Return a computed mapper device name from a crypto device name. func MakeDeviceMapperName(devName string) string { if strings.Contains(devName, "/") { @@ -168,9 +180,32 @@ func EncryptFS(progressOut io.Writer, client *keyserv.CryptClient, return "", err } + // Step 1 - ask server for an encryption key + mountPoints := fs.ParseMtab() + srcDirMount, found := mountPoints.GetMountPointOfPath(srcDir) + if !found { + return "", fmt.Errorf(MSG_E_SRC_DIR_MOUNT_NOT_FOUND, srcDir) + } + cryptDevUUID := MakeUUID() + salt, err := client.GetSalt() + if err != nil { + return "", fmt.Errorf(MSG_E_RPC_KEY_CREATE, err) + } + encryptionKeyResp, err := client.CreateKey(keyserv.CreateKeyReq{ + Password: keyserv.HashPassword(salt, password), + UUID: cryptDevUUID, + MountPoint: srcDir, + MountOptions: srcDirMount.Options, + MaxActive: keyMaxActive, + AliveIntervalSec: keyAliveIntervalSec, + AliveCount: keyAliveCount, + }) + if err != nil { + return "", fmt.Errorf(MSG_E_RPC_KEY_CREATE, err) + } + // Step 1. Un-mount the disk to encrypt fmt.Fprintf(progressOut, MSG_STEP_1, encDisk) - mountPoints := fs.ParseMtab() for { // Repeat until the disk has no more mount points if mountPoint, found := mountPoints.GetByCriteria(encDisk, "", ""); found { @@ -183,23 +218,13 @@ func EncryptFS(progressOut io.Writer, client *keyserv.CryptClient, break } // Step 1 (cont). Wipe the disk and install encryption key - encryptKey := make([]byte, fs.LUKS_KEY_SIZE_I) - _, err = rand.Read(encryptKey) - if err != nil { - return "", fmt.Errorf(MSG_E_NO_RAND, err) - - } - if err := fs.CryptFormat(encryptKey, encDisk); err != nil { + if err := fs.CryptFormat(encryptionKeyResp.KeyContent, encDisk, cryptDevUUID); err != nil { return "", err } dmName := MakeDeviceMapperName(encDisk) - if err := fs.CryptOpen(encryptKey, encDisk, dmName); err != nil { + if err := fs.CryptOpen(encryptionKeyResp.KeyContent, encDisk, dmName); err != nil { return "", err } - srcDirMount, found := mountPoints.GetMountPointOfPath(srcDir) - if !found { - return "", fmt.Errorf(MSG_E_SRC_DIR_MOUNT_NOT_FOUND, srcDir) - } encDiskMapper := path.Join("/dev/mapper", dmName) if err := fs.Format(encDiskMapper, srcDirMount.FileSystem); err != nil { return "", err @@ -243,21 +268,6 @@ func EncryptFS(progressOut io.Writer, client *keyserv.CryptClient, if !found { return "", fmt.Errorf(MSG_E_NO_DEV_INFO, encDisk) } - keyRecord := keydb.Record{ - UUID: cryptDev.UUID, - Key: encryptKey, - MountPoint: srcDir, - MountOptions: srcDirMount.Options, - MaxActive: keyMaxActive, - AliveIntervalSec: keyAliveIntervalSec, - AliveCount: keyAliveCount, - } - if err := client.SaveKey(keyserv.SaveKeyReq{ - Password: password, - Record: keyRecord, - }); err != nil { - return "", fmt.Errorf(MSG_E_RPC_KEY_SAVE, err) - } fmt.Fprintf(progressOut, MSG_OK_CONGRATS, srcDir, encDisk, srcDataDir) return cryptDev.UUID, nil } diff --git a/routine/encrypt_test.go b/routine/encrypt_test.go index 0518c8a..1b570ad 100644 --- a/routine/encrypt_test.go +++ b/routine/encrypt_test.go @@ -8,7 +8,7 @@ import ( "fmt" "github.com/HouzuoGuo/cryptctl/fs" "github.com/HouzuoGuo/cryptctl/keydb" - "github.com/HouzuoGuo/cryptctl/keyrpc" + "github.com/HouzuoGuo/cryptctl/keyserv" "github.com/HouzuoGuo/cryptctl/sys" "io/ioutil" "log" @@ -75,8 +75,8 @@ func TestEncryptDecrypt(t *testing.T) { passHash := keyserv.HashPassword(salt, keyserv.TEST_RPC_PASS) sysconf := keyserv.GetDefaultKeySvcConf() sysconf.Set(keyserv.SRV_CONF_KEYDB_DIR, keydbDir) - sysconf.Set(keyserv.SRV_CONF_TLS_CERT, path.Join(keyserv.PkgInGopath, "keyrpc", "rpc_test.crt")) - sysconf.Set(keyserv.SRV_CONF_TLS_KEY, path.Join(keyserv.PkgInGopath, "keyrpc", "rpc_test.key")) + sysconf.Set(keyserv.SRV_CONF_TLS_CERT, path.Join(keyserv.PkgInGopath, "keyserv", "rpc_test.crt")) + sysconf.Set(keyserv.SRV_CONF_TLS_KEY, path.Join(keyserv.PkgInGopath, "keyserv", "rpc_test.key")) sysconf.Set(keyserv.SRV_CONF_PASS_SALT, hex.EncodeToString(salt[:])) sysconf.Set(keyserv.SRV_CONF_PASS_HASH, hex.EncodeToString(passHash[:])) // To test email notification, simply start postfix at its default configuration @@ -100,16 +100,17 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - go func() { - srv.ListenRPC() - }() + if err := srv.ListenRPC(); err != nil { + t.Fatal(err) + } + go srv.HandleConnections() // Make an RPC client time.Sleep(2 * time.Second) - certContent, err := ioutil.ReadFile(path.Join(keyserv.PkgInGopath, "keyrpc", "rpc_test.crt")) + certContent, err := ioutil.ReadFile(path.Join(keyserv.PkgInGopath, "keyserv", "rpc_test.crt")) if err != nil { t.Fatal(err) } - client, err := keyserv.NewCryptClient("localhost", 3737, certContent) + client, err := keyserv.NewCryptClient("localhost", 3737, certContent, "", "") if err != nil { t.Fatal(err) } @@ -312,9 +313,10 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - go func() { - srv.ListenRPC() - }() + if err := srv.ListenRPC(); err != nil { + t.Fatal(err) + } + go srv.HandleConnections() // There's no need to make a new RPC client because the client does not hold a persistent connection if err := ManOnlineUnlockFS(os.Stdout, client, keyserv.TEST_RPC_PASS); err != nil { @@ -401,9 +403,10 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - go func() { - srv.ListenRPC() - }() + if err := srv.ListenRPC(); err != nil { + t.Fatal(err) + } + go srv.HandleConnections() // Check result from each unlock attempt onlineUnlockErr := make([]error, 5) for i := 0; i < 5; i++ { @@ -421,7 +424,7 @@ func TestEncryptDecrypt(t *testing.T) { checkSecret0() checkSecret1() // Alive messages should have been sent by ReportAlive - if msgs := srv.KeyDB.Records[loop0Dev.UUID].AliveMessages["127.0.0.1"]; len(msgs) == 0 { + if msgs := srv.KeyDB.RecordsByUUID[loop0Dev.UUID].AliveMessages["127.0.0.1"]; len(msgs) == 0 { t.Fatal(msgs) } // Sending alive message to non-existing reports should result in immediate rejection @@ -433,7 +436,7 @@ func TestEncryptDecrypt(t *testing.T) { and their goroutines will end. */ srv.KeyDB.Lock.Lock() - id0Record := srv.KeyDB.Records[encUUID0] + id0Record := srv.KeyDB.RecordsByUUID[encUUID0] id0Record.AliveMessages = map[string][]keydb.AliveMessage{ "NewHost1": []keydb.AliveMessage{ { @@ -449,8 +452,8 @@ func TestEncryptDecrypt(t *testing.T) { Timestamp: time.Now().Unix(), }, }} - srv.KeyDB.Records[encUUID0] = id0Record - id1Record := srv.KeyDB.Records[encUUID1] + srv.KeyDB.RecordsByUUID[encUUID0] = id0Record + id1Record := srv.KeyDB.RecordsByUUID[encUUID1] id1Record.AliveMessages = map[string][]keydb.AliveMessage{ "NewHost1": []keydb.AliveMessage{ { @@ -466,7 +469,7 @@ func TestEncryptDecrypt(t *testing.T) { Timestamp: time.Now().Unix(), }, }} - srv.KeyDB.Records[encUUID1] = id1Record + srv.KeyDB.RecordsByUUID[encUUID1] = id1Record srv.KeyDB.Lock.Unlock() reportAliveMayEnd = true fmt.Println("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@") @@ -483,11 +486,11 @@ func TestEncryptDecrypt(t *testing.T) { */ resetDisks() // Now that server has shut down, try to unlock disks using key records only. - if len(srv.KeyDB.Records) != 2 { - t.Fatal(srv.KeyDB.Records) + if len(srv.KeyDB.RecordsByUUID) != 2 { + t.Fatal(srv.KeyDB.RecordsByUUID) } records := make([]keydb.Record, 0, 0) - for _, record := range srv.KeyDB.Records { + for _, record := range srv.KeyDB.RecordsByUUID { records = append(records, record) fmt.Println("Offline-unlocking", record.MountPoint, record.UUID) if err := UnlockFS(os.Stdout, record); err != nil { @@ -529,8 +532,8 @@ func TestEncryptDecrypt(t *testing.T) { if err := EraseKey(os.Stdout, client, keyserv.TEST_RPC_PASS, encUUID1); err != nil { t.Fatal(err) } - if len(srv.KeyDB.Records) != 0 { - t.Fatal(srv.KeyDB.Records) + if len(srv.KeyDB.RecordsByUUID) != 0 { + t.Fatal(srv.KeyDB.RecordsByUUID) } // Both file systems should now be crypto-closed if _, err := fs.CryptStatus(loop0Crypt); err == nil { diff --git a/routine/unlock.go b/routine/unlock.go index 27cd1d7..17d3a1d 100644 --- a/routine/unlock.go +++ b/routine/unlock.go @@ -7,7 +7,7 @@ import ( "fmt" "github.com/HouzuoGuo/cryptctl/fs" "github.com/HouzuoGuo/cryptctl/keydb" - "github.com/HouzuoGuo/cryptctl/keyrpc" + "github.com/HouzuoGuo/cryptctl/keyserv" "github.com/HouzuoGuo/cryptctl/sys" "io" "os" @@ -37,10 +37,14 @@ func ManOnlineUnlockFS(progressOut io.Writer, client *keyserv.CryptClient, passw return errors.New("Cannot find any more encrypted file systems.") } hostname, _ := sys.GetHostnameAndIP() + salt, err := client.GetSalt() + if err != nil { + return err + } resp, err := client.ManualRetrieveKey(keyserv.ManualRetrieveKeyReq{ UUIDs: reqUUIDs, Hostname: hostname, - Password: password, + Password: keyserv.HashPassword(salt, password), }) if err != nil { return err @@ -245,7 +249,11 @@ func EraseKey(progressOut io.Writer, client *keyserv.CryptClient, password, uuid } // After metadata is erased, ask server to remove its key record as well. hostname, _ := sys.GetHostnameAndIP() - if err := client.EraseKey(keyserv.EraseKeyReq{Password: password, Hostname: hostname, UUID: uuid}); err != nil { + salt, err := client.GetSalt() + if err != nil { + return err + } + if err := client.EraseKey(keyserv.EraseKeyReq{Password: keyserv.HashPassword(salt, password), Hostname: hostname, UUID: uuid}); err != nil { return err } fmt.Fprintf(progressOut, "Encryption header has been wiped successfully, data in \"%s\" (%s) is now irreversibly lost.\n", diff --git a/sys/term.go b/sys/term.go index a5bd547..9adb470 100644 --- a/sys/term.go +++ b/sys/term.go @@ -77,7 +77,7 @@ func InputInt(mandatory bool, defaultHint, lowerLimit, upperLimit int, format st for { valStr := Input(mandatory, strconv.Itoa(defaultHint), format, values...) if valStr == "" { - return 0 + return defaultHint } valInt, err := strconv.Atoi(valStr) if err != nil {