diff --git a/README.md b/README.md index 6445a4f..ffb931c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,2 @@ # gutils -Another (24183746th) library related to Go utils. Common and useful units which can be consumed to another repo to ease the implementation. +Another (24183746th) library related to Go utils. Common and useful units which can be consumed to another repo to ease the implementation. Separated by sub packages from the original ones. diff --git a/bytex/bytex.go b/bytex/bytex.go new file mode 100644 index 0000000..6e57da6 --- /dev/null +++ b/bytex/bytex.go @@ -0,0 +1,102 @@ +// Package bytex contains byte (string) processing functions. +package bytex + +import ( + "unicode" + "unicode/utf8" +) + +// ByteContainsFold is like bytes.Contains but uses Unicode case-folding. +func ByteContainsFold(s, substr []byte) bool { + return ByteIndexFold(s, substr) >= 0 +} + +// EqualFoldRune compares a and b runes whether they fold equally. +// The code comes from strings.EqualFold, but shortened to only one rune. +func EqualFoldRune(sr, tr rune) bool { + if sr == tr { + return true + } + + // Make sr < tr to simplify what follows. + if tr < sr { + sr, tr = tr, sr + } + + // Fast check for ASCII. + if tr < utf8.RuneSelf && 'A' <= sr && sr <= 'Z' { + // ASCII, and sr is upper case. tr must be lower case. + return tr == sr+'a'-'A' + } + + // General case. SimpleFold(x) returns the next equivalent rune > x + // or wraps around to smaller values. + r := unicode.SimpleFold(sr) + for r != sr && r < tr { + r = unicode.SimpleFold(r) + } + + return r == tr +} + +// ByteIndexFold is like bytes.Contains but uses Unicode case-folding. +func ByteIndexFold(s, substr []byte) int { + if len(substr) == 0 { + return 0 + } + if len(s) == 0 { + return -1 + } + + firstRune := rune(substr[0]) + if firstRune >= utf8.RuneSelf { + firstRune, _ = utf8.DecodeRune(substr) + } + + pos := 0 + for { + rune, size := utf8.DecodeRune(s) + if EqualFoldRune(rune, firstRune) && ByteHasPrefixFold(s, substr) { + return pos + } + pos += size + s = s[size:] + if len(s) == 0 { + break + } + } + + return -1 +} + +// ByteHasPrefixFold is like strings.HasPrefix but uses Unicode case-folding. +func ByteHasPrefixFold(s, prefix []byte) bool { + if len(prefix) == 0 { + return true + } + + for { + pr, prSize := utf8.DecodeRune(prefix) + prefix = prefix[prSize:] + if len(s) == 0 { + return false + } + + // Step with s, too. + sr, size := utf8.DecodeRune(s) + if sr == utf8.RuneError { + return false + } + + s = s[size:] + if !EqualFoldRune(sr, pr) { + return false + } + + if len(prefix) == 0 { + break + } + } + + return true +} diff --git a/bytex/bytex_test.go b/bytex/bytex_test.go new file mode 100644 index 0000000..d52e3da --- /dev/null +++ b/bytex/bytex_test.go @@ -0,0 +1,31 @@ +package bytex + +import "testing" + +func TestByteIndexFold(t *testing.T) { + for i, tc := range []struct { + haystack, needle string + want int + }{ + {"body", "body", 0}, + } { + got := ByteIndexFold([]byte(tc.haystack), []byte(tc.needle)) + if got != tc.want { + t.Errorf("%d. got %d, wanted %d.", i, got, tc.want) + } + } +} + +func TestByteHasPrefixFold(t *testing.T) { + for i, tc := range []struct { + haystack, needle string + want bool + }{ + {"body", "body", true}, + } { + got := ByteHasPrefixFold([]byte(tc.haystack), []byte(tc.needle)) + if got != tc.want { + t.Errorf("%d. got %t, wanted %t.", i, got, tc.want) + } + } +} diff --git a/convertx/convertx.go b/convertx/convertx.go new file mode 100644 index 0000000..4540b99 --- /dev/null +++ b/convertx/convertx.go @@ -0,0 +1,79 @@ +package convertx + +import ( + "encoding/binary" + "fmt" + "reflect" + "strconv" + "strings" +) + +// Accessible to external packages ///////////////////////////////////////////// + +// Atoi ... +func Atoi(str interface{}) int { + return atoi(str) +} + +// Atoui ... +func Atoui(str interface{}) uint { + return atoui(str) +} + +// ToInt64 should convert given value to int64. +func ToInt64(v interface{}) (d int64, err error) { + return toInt64(v) +} + +// BytesToInt64 should convert bytes to int64. +func BytesToInt64(buf []byte) int64 { + return bytesToInt64(buf) +} + +// Int64ToBytes should convert int64 to bytes. +func Int64ToBytes(i int64) []byte { + return int64ToBytes(i) +} + +// Underlying functions //////////////////////////////////////////////////////// + +func atoi(str interface{}) (i int) { + if str == nil { + return 0 + } + i, _ = strconv.Atoi(strings.Trim(str.(string), " ")) + return +} + +func atoui(str interface{}) uint { + if str == nil { + return 0 + } + u, _ := strconv.Atoi(strings.Trim(str.(string), " ")) + return uint(u) +} + +func toInt64(v interface{}) (d int64, err error) { + val := reflect.ValueOf(v) + switch v.(type) { + case int, int8, int16, int32, int64: + d = val.Int() + case uint, uint8, uint16, uint32, uint64: + d = int64(val.Uint()) + default: + err = fmt.Errorf("ToInt64 need numeric not `%T`", v) + } + return +} + +func bytesToInt64(buf []byte) int64 { + return int64(binary.BigEndian.Uint64(buf)) +} + +func int64ToBytes(i int64) []byte { + var buf = make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(i)) + return buf +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cryptox/cryptox.go b/cryptox/cryptox.go new file mode 100644 index 0000000..c99ef3a --- /dev/null +++ b/cryptox/cryptox.go @@ -0,0 +1,244 @@ +package cryptox + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "io" + "io/ioutil" + "os" + "sync" + "time" + + "golang.org/x/crypto/nacl/secretbox" + "golang.org/x/crypto/scrypt" +) + +const minL2N = 14 + +var ( + // DefaultTimeout ... + DefaultTimeout = 5 * time.Second + dursMu sync.Mutex + durs = make([]time.Duration, 0, 8) +) + +type ( + // WriteCloser ... + WriteCloser struct { + io.Writer + } + + // SecretWriter ... + SecretWriter struct { + key, nonce []byte + w io.WriteCloser + buf bytes.Buffer + } + + // Key ... + Key struct { + Bytes []byte `json:"-"` + Salt []byte + L2N uint + R, P int + } +) + +// Close ... +func (wc WriteCloser) Close() error { + if c, ok := wc.Writer.(io.Closer); ok { + return c.Close() + } + return nil +} + +// Write ... +func (sw *SecretWriter) Write(p []byte) (int, error) { + return sw.buf.Write(p) +} + +// Close ... +func (sw *SecretWriter) Close() error { + var ( + key [32]byte + nonce [24]byte + ) + + copy(key[:], sw.key) + copy(nonce[:], sw.nonce) + out := make([]byte, 0, sw.buf.Len()+secretbox.Overhead) + _, err := sw.w.Write(secretbox.Seal(out, sw.buf.Bytes(), &nonce, &key)) + if err != nil { + return err + } + + return sw.w.Close() +} + +func (key Key) String() string { + k := struct { + Bytes, Salt []byte + L2N uint + R, P int + }(key) + b, err := json.Marshal(k) + if err != nil { + return err.Error() + } + return string(b) +} + +// Populate ... +func (key *Key) Populate(password []byte, keyLen int) error { + var err error + key.Bytes, err = scrypt.Key(password, key.Salt, 1< timeout { + key.L2N-- + } + break + } + deadline := time.Now().Add(timeout) + for now := time.Now(); now.Before(deadline); { + if key.Bytes, err = scrypt.Key(password, salt, 1< i { + durs = durs[:i+1] + } else { + durs = append(durs, make([]time.Duration, len(durs))...) + } + } + durs[key.L2N-minL2N] = dur + now = now2 + + if now.Add(2 * dur).After(deadline) { + break + } + key.L2N++ + } + return key, nil +} + +// Encrypt binary data to a base64 string with AES using the key provided. +func Encrypt(keyString string, data []byte) (string, error) { + key := []byte(keyString) + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + ciphertext := make([]byte, aes.BlockSize+len(data)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", err + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], data) + + return base64.URLEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt a base64 string to binary data with AES using the key provided. +func Decrypt(keyString string, base64Data string) ([]byte, error) { + key := []byte(keyString) + ciphertext, err := base64.URLEncoding.DecodeString(base64Data) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + if len(ciphertext) < aes.BlockSize { + return nil, errors.New("Ciphertext provided is smaller than AES block size") + } + + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(ciphertext, ciphertext) + + return ciphertext, nil +} diff --git a/cryptox/cryptox_test.go b/cryptox/cryptox_test.go new file mode 100644 index 0000000..7581135 --- /dev/null +++ b/cryptox/cryptox_test.go @@ -0,0 +1,89 @@ +package cryptox + +import ( + "fmt" + "testing" +) + +const ( + testKeyString = "abcd1234abcd1234" + testPlaintext = "HELLO WORLD" + testCiphertext = "eTtvSIEnXOL6rhMSznY6HgfntkuWHZA16Z_s" +) + +func TestInvalidKeysAndData(t *testing.T) { + // Encrypt with empty key + data := []byte(testPlaintext) + _, err := Encrypt("", data) + if err == nil { + t.Errorf("Encrypt succeeded with an empty key") + } + + // Decrypt with empty key + _, err = Decrypt("", testCiphertext) + if err == nil { + t.Errorf("Decrypt succeeded with an empty key") + } + + // Decrypt with an incorrect key + res, err := Decrypt("aaaaffff12345678", testCiphertext) + if err != nil { + t.Errorf("Decrypt failed: %v", err) + } + if string(res) == testPlaintext { + t.Errorf("Decrypt for '%s' succeeded with an incorrect key", testPlaintext) + } + + // Decrypt with the correct key + res, err = Decrypt(testKeyString, testCiphertext) + if err != nil { + t.Errorf("Decrypt failed: %v", err) + } + if string(res) != testPlaintext { + t.Errorf("Decrypt for '%s' failed with an correct key", testPlaintext) + } + + // Decrypt an short string (i.e. smaller than block size) + _, err = Decrypt(testKeyString, "aaaabbbbcccc") + if err == nil { + t.Errorf("Decrypt succeeded with an invalid key size") + } + + // Decrypt a non-base64 string + _, err = Decrypt(testKeyString, fmt.Sprintf("%s#@?`", testCiphertext)) + if err == nil { + t.Errorf("Decrypt succeeded with an invalid base64 string") + } +} + +func TestEncryptAndDecrypt(t *testing.T) { + tests := []string{ + testPlaintext, + "I love jam", + testPlaintext, + "pls work already", + "", + "haaaalp", + } + + for i, testData := range tests { + t.Run(fmt.Sprintf("test-%v", i), func(t *testing.T) { + data := []byte(testData) + cipher, err := Encrypt(testKeyString, data) + if err != nil { + t.Fatal(err) + } + if string(cipher) == "" { + t.Fatalf("Encrypt failed, cipher result is empty string") + } + + plain, err := Decrypt(testKeyString, cipher) + if err != nil { + t.Fatal(err) + } + if plain == nil { + t.Fatal("Decrypt failed, plaintext result is nil") + } + }) + } +} diff --git a/errorsx/errorsx.go b/errorsx/errorsx.go new file mode 100644 index 0000000..cd31999 --- /dev/null +++ b/errorsx/errorsx.go @@ -0,0 +1,391 @@ +package errorsx + +import ( + "fmt" + "net/http" + "strings" +) + +// DomainError represents an error thrown by the domain. +type DomainError struct { + Status int `json:"status_code"` + Message string `json:"error"` +} + +// NewDomainError holds the value of status code, message and errors. +func NewDomainError(status int, message string) *DomainError { + return &DomainError{status, message} +} + +// Error should return the string value of the error. +func (err *DomainError) Error() string { + return err.Message +} + +// 1xx ------------------------------------------------------------------------- + +// Continue should return `http.StatusContinue` with custom message. +func Continue(message string) *DomainError { // 100 + return &DomainError{http.StatusContinue, message} +} + +// SwitchingProtocols should return `http.StatusSwitchingProtocols` with custom +// message. +func SwitchingProtocols(message string) *DomainError { // 101 + return &DomainError{http.StatusSwitchingProtocols, message} +} + +// Processing should return `http.StatusProcessing` with custom message. +func Processing(message string) *DomainError { // 102 + return &DomainError{http.StatusProcessing, message} +} + +// 2xx ------------------------------------------------------------------------- + +// OK should return `http.StatusOK` with custom message. +func OK(message string) *DomainError { // 200 + return &DomainError{http.StatusOK, message} +} + +// Created should return `http.StatusOK` with custom message. +func Created(message string) *DomainError { // 201 + return &DomainError{http.StatusCreated, message} +} + +// Accepted should return `http.StatusAccepted` with custom message. +func Accepted(message string) *DomainError { // 202 + return &DomainError{http.StatusAccepted, message} +} + +// NonAuthoritativeInfo should return `http.StatusNonAuthoritativeInfo` +// with custom message. +func NonAuthoritativeInfo(message string) *DomainError { // 203 + return &DomainError{http.StatusNonAuthoritativeInfo, message} +} + +// NoContent should return `http.StatusNoContent` with custom message. +func NoContent(message string) *DomainError { // 204 + return &DomainError{http.StatusNoContent, message} +} + +// ResetContent should return `http.StatusResetContent` with custom message. +func ResetContent(message string) *DomainError { // 205 + return &DomainError{http.StatusResetContent, message} +} + +// PartialContent should return `http.StatusPartialContent` with custom message. +func PartialContent(message string) *DomainError { // 206 + return &DomainError{http.StatusPartialContent, message} +} + +// MultiStatus should return `http.StatusMultiStatus` with custom message. +func MultiStatus(message string) *DomainError { // 207 + return &DomainError{http.StatusMultiStatus, message} +} + +// AlreadyReported should return `http.StatusAlreadyReported` with custom message. +func AlreadyReported(message string) *DomainError { // 208 + return &DomainError{http.StatusAlreadyReported, message} +} + +// IMUsed should return `http.StatusIMUsed` with custom message. +func IMUsed(message string) *DomainError { // 209 + return &DomainError{http.StatusIMUsed, message} +} + +// 3xx ------------------------------------------------------------------------- + +// MultipleChoices should return `http.StatusMultipleChoices` with custom message. +func MultipleChoices(message string) *DomainError { // 300 + return &DomainError{http.StatusMultipleChoices, message} +} + +// MovedPermanently should return `http.StatusMovedPermanently` with custom message. +func MovedPermanently(message string) *DomainError { // 301 + return &DomainError{http.StatusMovedPermanently, message} +} + +// Found should return `http.StatusFound` with custom message. +func Found(message string) *DomainError { // 302 + return &DomainError{http.StatusFound, message} +} + +// SeeOther should return `http.StatusSeeOther` with custom message. +func SeeOther(message string) *DomainError { // 303 + return &DomainError{http.StatusSeeOther, message} +} + +// NotModified should return `http.StatusNotModified` with custom message. +func NotModified(message string) *DomainError { // 304 + return &DomainError{http.StatusNotModified, message} +} + +// UseProxy should return `http.StatusUseProxy` with custom message. +func UseProxy(message string) *DomainError { // 305 + return &DomainError{http.StatusUseProxy, message} +} + +// TemporaryRedirect should return `http.Status` with custom message. +func TemporaryRedirect(message string) *DomainError { // 307 + return &DomainError{http.StatusTemporaryRedirect, message} +} + +// PermanentRedirect should return `http.Status` with custom message. +func PermanentRedirect(message string) *DomainError { // 308 + return &DomainError{http.StatusPermanentRedirect, message} +} + +// 4xx ------------------------------------------------------------------------- + +// BadRequest should return `http.StatusBadRequest` with custom message. +func BadRequest(message string) *DomainError { // 400 + return &DomainError{http.StatusBadRequest, message} +} + +// Unauthorized should return `http.StatusUnauthorized` with custom message. +func Unauthorized(message string) *DomainError { // 401 + return &DomainError{http.StatusUnauthorized, message} +} + +// PaymentRequired should return `http.PaymentRequired` with custom message. +func PaymentRequired(message string) *DomainError { // 402 + return &DomainError{http.StatusPaymentRequired, message} +} + +// Forbidden should return `http.StatusForbidden` with custom message. +func Forbidden(message string) *DomainError { // 403 + return &DomainError{http.StatusForbidden, message} +} + +// NotFound should return `http.StatusNotFound` with custom message. +func NotFound(message string) *DomainError { // 404 + return &DomainError{http.StatusNotFound, message} +} + +// MethodNotAllowed should return `http.Status` with custom message. +func MethodNotAllowed(message string) *DomainError { // 405 + return &DomainError{http.StatusMethodNotAllowed, message} +} + +// NotAcceptable should return `http.StatusNotAcceptable` with custom message. +func NotAcceptable(message string) *DomainError { // 406 + return &DomainError{http.StatusNotAcceptable, message} +} + +// ProxyAuthRequired should return `http.StatusProxyAuthRequired` with custom message. +func ProxyAuthRequired(message string) *DomainError { // 407 + return &DomainError{http.StatusProxyAuthRequired, message} +} + +// RequestTimeout should return `http.Status` with custom message. +func RequestTimeout(message string) *DomainError { // 408 + return &DomainError{http.StatusRequestTimeout, message} +} + +// Conflict should return `http.StatusConflict` with custom message. +func Conflict(message string) *DomainError { // 409 + return &DomainError{http.StatusConflict, message} +} + +// Gone should return `http.Status` with custom message. +func Gone(message string) *DomainError { // 410 + return &DomainError{http.StatusGone, message} +} + +// LengthRequired should return `http.Status` with custom message. +func LengthRequired(message string) *DomainError { // 411 + return &DomainError{http.StatusLengthRequired, message} +} + +// PreconditionFailed should return `http.Status` with custom message. +func PreconditionFailed(message string) *DomainError { // 412 + return &DomainError{http.StatusPreconditionFailed, message} +} + +// RequestEntityTooLarge should return `http.Status` with custom message. +func RequestEntityTooLarge(message string) *DomainError { // 413 + return &DomainError{http.StatusRequestEntityTooLarge, message} +} + +// RequestURITooLong should return `http.Status` with custom message. +func RequestURITooLong(message string) *DomainError { // 414 + return &DomainError{http.StatusRequestURITooLong, message} +} + +// UnsupportedMediaType should return `http.StatusUnsupportedMediaType` with +// custom message. +func UnsupportedMediaType(message string) *DomainError { // 415 + return &DomainError{http.StatusUnsupportedMediaType, message} +} + +// RequestedRangeNotSatisfiable should return +// `http.StatusRequestedRangeNotSatisfiable` with custom message. +func RequestedRangeNotSatisfiable(message string) *DomainError { // 416 + return &DomainError{http.StatusRequestedRangeNotSatisfiable, message} +} + +// ExpectationFailed should return `http.StatusExpectationFailed` +// with custom message. +func ExpectationFailed(message string) *DomainError { // 418 + return &DomainError{http.StatusExpectationFailed, message} +} + +// Teapot should return `http.StatusTeapot` with custom message. +func Teapot(message string) *DomainError { // 418 + return &DomainError{http.StatusTeapot, message} +} + +// UnprocessableEntity should return `http.StatusUnprocessableEntity` with +// custom message. +func UnprocessableEntity(message string) *DomainError { // 422 + return &DomainError{http.StatusUnprocessableEntity, message} +} + +// Locked should return `http.StatusLocked` with custom message. +func Locked(message string) *DomainError { // 423 + return &DomainError{http.StatusLocked, message} +} + +// FailedDependency should return `http.StatusFailedDependency` +// with custom message. +func FailedDependency(message string) *DomainError { // 424 + return &DomainError{http.StatusFailedDependency, message} +} + +// UpgradeRequired should return `http.StatusUpgradeRequired` +// with custom message. +func UpgradeRequired(message string) *DomainError { // 426 + return &DomainError{http.StatusUpgradeRequired, message} +} + +// PreconditionRequired should return `http.StatusPreconditionRequired` +// with custom message. +func PreconditionRequired(message string) *DomainError { // 428 + return &DomainError{http.StatusPreconditionRequired, message} +} + +// TooManyRequests should return `http.Status` with custom message. +func TooManyRequests(message string) *DomainError { // 429 + return &DomainError{http.StatusTooManyRequests, message} +} + +// RequestHeaderFieldsTooLarge should return +// `http.StatusRequestHeaderFieldsTooLarge` with custom message. +func RequestHeaderFieldsTooLarge(message string) *DomainError { // 431 + return &DomainError{http.StatusRequestHeaderFieldsTooLarge, message} +} + +// UnavailableForLegalReasons should return +// `http.StatusUnavailableForLegalReasons` with custom message. +func UnavailableForLegalReasons(message string) *DomainError { // 451 + return &DomainError{http.StatusUnavailableForLegalReasons, message} +} + +// 5xx ------------------------------------------------------------------------- + +// InternalServer should return `http.StatusInternalServerError` +// with custom message. +func InternalServer(message string) *DomainError { // 500 + return &DomainError{http.StatusInternalServerError, message} +} + +// NotImplemented should return `http.StatusNotImplemented` with custom message. +func NotImplemented(message string) *DomainError { // 501 + return &DomainError{http.StatusNotImplemented, message} +} + +// BadGateway should return `http.Status` with custom message. +func BadGateway(message string) *DomainError { // 502 + return &DomainError{http.StatusBadGateway, message} +} + +// ServiceUnavailable should return `http.StatusServiceUnavailable` +// with custom message. +func ServiceUnavailable(message string) *DomainError { // 503 + return &DomainError{http.StatusServiceUnavailable, message} +} + +// GatewayTimeout should return `http.StatusGatewayTimeout` +// with custom message. +func GatewayTimeout(message string) *DomainError { // 504 + return &DomainError{http.StatusGatewayTimeout, message} +} + +// HTTPVersionNotSupported should return `http.StatusHTTPVersionNotSupported` +// with custom message. +func HTTPVersionNotSupported(message string) *DomainError { // 505 + return &DomainError{http.StatusHTTPVersionNotSupported, message} +} + +// VariantAlsoNegotiates should return `http.StatusVariantAlsoNegotiates` +// with custom message. +func VariantAlsoNegotiates(message string) *DomainError { // 506 + return &DomainError{http.StatusVariantAlsoNegotiates, message} +} + +// InsufficientStorage should return `http.StatusInsufficientStorage` +// with custom message. +func InsufficientStorage(message string) *DomainError { // 5 + return &DomainError{http.StatusInsufficientStorage, message} +} + +// LoopDetected should return `http.StatusLoopDetected` with custom message. +func LoopDetected(message string) *DomainError { // 508 + return &DomainError{http.StatusLoopDetected, message} +} + +// NotExtended should return `http.StatusNotExtended` with custom message. +func NotExtended(message string) *DomainError { // 510 + return &DomainError{http.StatusNotExtended, message} +} + +// NetworkAuthenticationRequired should return +// `http.StatusNetworkAuthenticationRequired` with custom message. +func NetworkAuthenticationRequired(message string) *DomainError { // 511 + return &DomainError{http.StatusNetworkAuthenticationRequired, message} +} + +// ----------------------------------------------------------------------------- + +// IsStatusNotModified should return true if HTTP status of an error is 204. +func (err *DomainError) IsStatusNotModified() bool { + return err.Status == http.StatusNotModified +} + +// IsStatusBadRequest should return true if HTTP status of an error is 400. +func (err *DomainError) IsStatusBadRequest() bool { + return err.Status == http.StatusBadRequest +} + +// IsStatusUnauthorized should return true if HTTP status of an error is 401. +func (err *DomainError) IsStatusUnauthorized() bool { + return err.Status == http.StatusUnauthorized +} + +// IsStatusNotFound should return true if HTTP status of an error is 404. +func (err *DomainError) IsStatusNotFound() bool { + return err.Status == http.StatusNotFound +} + +// IsStatusConflict should return true if HTTP status of an error is 409. +func (err *DomainError) IsStatusConflict() bool { + return err.Status == http.StatusConflict +} + +// ----------------------------------------------------------------------------- + +// NotUniqueTogether should return error for unique together fields. +func NotUniqueTogether(ss ...string) error { + if len(ss) == 0 { + return nil + } + + return fmt.Errorf("'%s' must be unique", strings.Join(ss, ", ")) +} + +// NotUnique should return an error for unique together fields. +func NotUnique(s string) error { + return NotUniqueTogether(s) +} + +// ----------------------------------------------------------------------------- diff --git a/filepathx/constants.go b/filepathx/constants.go new file mode 100644 index 0000000..e3a01d2 --- /dev/null +++ b/filepathx/constants.go @@ -0,0 +1,6 @@ +package filepathx + +const ( + // StrUnderscore holds the string value of underscore. + StrUnderscore = "_" +) diff --git a/filepathx/filepathx.go b/filepathx/filepathx.go new file mode 100644 index 0000000..f5a292b --- /dev/null +++ b/filepathx/filepathx.go @@ -0,0 +1,557 @@ +// FIXME: +// - Doc is not proper. +// - Test cases are remaining. + +// Package filepathx provides extra utilities for file and filepath related operations. +package filepathx + +import ( + "bufio" + "errors" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "path" + "path/filepath" + "regexp" + "strings" + + "github.com/aljiwala/gutils/strx" +) + +var curpath = SelfDir() + +type dirNames struct { + names []string + err error +} + +// Dir gets directory name of the filepath. +func Dir(file string) string { + return path.Dir(file) +} + +// Ext should return extension of the given file. +func Ext(file string) string { + return path.Ext(file) +} + +// Rename should rename the file. +func Rename(file string, to string) error { + return os.Rename(file, to) +} + +// Remove should remove the file. +func Remove(file string) error { + return os.Remove(file) +} + +// Basename gets base name of provided filepath. +func Basename(file string) string { + return path.Base(file) +} + +// SelfPath gets compiled executable file absolute path. +func SelfPath() (path string) { + path, _ = filepath.Abs(os.Args[0]) + return +} + +// SelfDir gets compiled executable file directory. +func SelfDir() string { + return filepath.Dir(SelfPath()) +} + +// SelfChdir switch the working path to my own path. +func SelfChdir() { + if err := os.Chdir(curpath); err != nil { + log.Fatal(err) + } +} + +// IsDir should return true and nil as error if provided path is of directory. +func IsDir(path string) (bool, error) { + // Stat returns a FileInfo describing the named file. + // If there is an error, it will be of type *PathError. + fileInfo, err := os.Stat(path) + if err != nil { + // no such file or dir + return false, errors.New("No such file or directory") + } + + // Return true if it's directory. + if fileInfo.IsDir() { + return true, nil + } + return false, nil +} + +// GetWDPath gets the work directory path. +func GetWDPath() (wd string) { + wd = os.Getenv("GOPATH") + if wd == "" { + panic("GOPATH is not setted in env.") + } + return +} + +// DoesExist checks whether a file or directory exists. It returns false if not. +func DoesExist(path string) bool { + // Stat returns a FileInfo describing the named file. + // If there is an error, it will be of type *PathError. + _, err := os.Stat(path) + return err == nil || os.IsExist(err) +} + +// DoesNotExist should check if file/folder doesn't exist on provided path. +func DoesNotExist(path string) bool { + // Stat returns a FileInfo describing the named file. + // If there is an error, it will be of type *PathError. + _, err := os.Stat(path) + + // If the file doesn't exists, we will get an error. + // Thus, we can use this to check: + if err != nil { + // IsNotExist returns a boolean indicating whether the error is known to + // report that a file or directory does not exist. + // It's satisfied by ErrNotExist as well as some syscall errors. + if os.IsNotExist(err) { + // Not exist. + return true + } + } + + // Exist. + return false +} + +// BasePath should return base file path. +func BasePath(path string) string { + n := strings.LastIndexByte(path, '.') + if n > 0 { + return path[:n] + } + return path +} + +// RelPath returns a relative path that is lexically equivalent to targpath. +func RelPath(targpath string) string { + basepath, _ := filepath.Abs("./") + rel, _ := filepath.Rel(basepath, targpath) + return strings.Replace(rel, `\`, `/`, -1) +} + +// // RealPath gets absolute filepath, based on built executable file. +// func RealPath(file string) (string, error) { +// if path.IsAbs(file) { +// return file, nil +// } +// wd, err := os.Getwd() +// return path.Join(wd, file), err +// } + +// FileMTime should get and return modified time of the file. +func FileMTime(file string) (int64, error) { + f, e := os.Stat(file) + if e != nil { + return 0, e + } + return f.ModTime().Unix(), nil +} + +// FileSize should get return size as bytes. +func FileSize(file string) (int64, error) { + f, e := os.Stat(file) + if e != nil { + return 0, e + } + return f.Size(), nil +} + +// IsFile checks whether the path is a file, +// it returns false when it's a directory or does not exist. +func IsFile(filePath string) bool { + f, e := os.Stat(filePath) + if e != nil { + return false + } + return !f.IsDir() +} + +// DirsUnder should return list of dirs under given dirPath. +func DirsUnder(dirPath string) (container []string, err error) { + var fs []os.FileInfo + if !DoesExist(dirPath) { + return + } + + fs, err = ioutil.ReadDir(dirPath) + if err != nil { + return + } + + sz := len(fs) + if sz == 0 { + return + } + + for i := 0; i < sz; i++ { + if fs[i].IsDir() { + name := fs[i].Name() + if name != "." && name != ".." { + container = append(container, name) + } + } + } + + return +} + +// FilesUnder should return list of files under given dirPath. +func FilesUnder(dirPath string) (container []string, err error) { + var fs []os.FileInfo + if !DoesExist(dirPath) { + return + } + + fs, err = ioutil.ReadDir(dirPath) + if err != nil { + return + } + + sz := len(fs) + if sz == 0 { + return + } + + for i := 0; i < sz; i++ { + if !fs[i].IsDir() { + container = append(container, fs[i].Name()) + } + } + + return +} + +// SearchFile searches a file in paths. +// Often used in search config file in `/etc`. +func SearchFile(filename string, paths ...string) (fullPath string, err error) { + for _, path := range paths { + if fullPath = filepath.Join(path, filename); DoesExist(fullPath) { + return + } + } + err = errors.New(fullPath + " not found in paths") + return +} + +// IsDirExists should return true if provided path is directory. +func IsDirExists(path string) bool { + fi, err := os.Stat(path) + if err != nil { + return os.IsExist(err) + } + return fi.IsDir() +} + +// IsBinaryExist searches for an executable binary named file in the directories. +func IsBinaryExist(binary string) bool { + if _, err := exec.LookPath(binary); err != nil { + return false + } + return true +} + +// CreateDir should create directory (if it doesn't exist) based on provided path. +func CreateDir(dirStr string) error { + if _, err := os.Stat(dirStr); os.IsNotExist(err) { + err := os.MkdirAll(dirStr, 0777) + if err != nil { + log.Printf("Failed to create directory `%s`: %v", dirStr, err) + return err + } + } + return nil +} + +// FilenameReplace should replace illegal filename with similar characters. +func FilenameReplace(filename string) (rfn string) { + // Replace “” with "". + if strings.Count(filename, `"`) > 0 { + var i = 1 + label: + for k, v := range []byte(filename) { + if string(v) != `"` { + continue + } + if i%2 == 1 { + filename = string(filename[:k]) + `“` + string(filename[k+1:]) + } else { + filename = string(filename[:k]) + `”` + string(filename[k+1:]) + } + i++ + goto label + } + } + + replace := strings.Replace + rfn = replace(filename, `:`, `:`, -1) + rfn = replace(rfn, `*`, `ж`, -1) + rfn = replace(rfn, `<`, `<`, -1) + rfn = replace(rfn, `>`, `>`, -1) + rfn = replace(rfn, `?`, `?`, -1) + rfn = replace(rfn, `/`, `/`, -1) + rfn = replace(rfn, `|`, `∣`, -1) + rfn = replace(rfn, `\`, `╲`, -1) + return +} + +// ExcelSheetNameReplace should replace the illegal characters in the excel +// worksheet name with the underscore. +func ExcelSheetNameReplace(filename string) (rfn string) { + us := StrUnderscore + replace := strings.Replace + rfn = replace(filename, `:`, us, -1) + rfn = replace(rfn, `:`, ``, -1) + rfn = replace(rfn, `*`, us, -1) + rfn = replace(rfn, `?`, us, -1) + rfn = replace(rfn, `?`, us, -1) + rfn = replace(rfn, `/`, us, -1) + rfn = replace(rfn, `/`, us, -1) + rfn = replace(rfn, `\`, us, -1) + rfn = replace(rfn, `╲`, us, -1) + rfn = replace(rfn, `]`, us, -1) + rfn = replace(rfn, `[`, us, -1) + return +} + +func grepFile(pattern, filename string) (lines []string, err error) { + var isLongLine bool + re, err := regexp.Compile(pattern) + if err != nil { + return + } + + fd, err := os.Open(filename) + if err != nil { + return + } + + prefix := "" + lines = make([]string, 0) + reader := bufio.NewReader(fd) + + for { + byteLine, isPrefix, er := reader.ReadLine() + + if er != nil && er != io.EOF { + return nil, er + } + if er == io.EOF { + break + } + + line := string(byteLine) + if isPrefix { + prefix += line + continue + } else { + isLongLine = true + } + + line = prefix + line + if isLongLine { + prefix = "" + } + + if re.MatchString(line) { + lines = append(lines, line) + } + } + + return +} + +// GrepFile like command grep -E +// for example: GrepFile(`^hello`, "hello.txt") +// \n is striped while read +func GrepFile(pattern string, filename string) (lines []string, err error) { + return grepFile(pattern, filename) +} + +// walk recursively descends path, calling w. +func walk(path string, info os.FileInfo, walkFn filepath.WalkFunc, followSymlinks bool) error { + stat := os.Lstat + if followSymlinks { + stat = os.Stat + } + err := walkFn(path, info, nil) + if err != nil { + if info.IsDir() && err == filepath.SkipDir { + return nil + } + return err + } + + if !info.IsDir() { + return nil + } + + c, err := readDirNames(path) + if err != nil { + return walkFn(path, info, err) + } + + for names := range c { + if names.err != nil { + return walkFn(path, info, names.err) + } + for _, name := range names.names { + filename := filepath.Join(path, name) + fileInfo, err := stat(filename) + if err != nil { + if err = walkFn(filename, fileInfo, err); err != nil && err != filepath.SkipDir { + return err + } + } else { + err = walk(filename, fileInfo, walkFn, followSymlinks) + if err != nil { + if !fileInfo.IsDir() || err != filepath.SkipDir { + return err + } + } + } + } + } + return nil +} + +// Walk walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. The files are walked UNORDERED, +// which makes the output undeterministic! +// Walk does not follow symbolic links. +func Walk(root string, walkFn filepath.WalkFunc) error { + info, err := os.Lstat(root) + if err != nil { + return walkFn(root, nil, err) + } + return walk(root, info, walkFn, false) +} + +// WalkWithSymlinks walks the file tree rooted at root, calling walkFn for each file or +// directory in the tree, including root. All errors that arise visiting files +// and directories are filtered by walkFn. The files are walked UNORDERED, +// which makes the output undeterministic! +// WalkWithSymlinks does follow symbolic links! +func WalkWithSymlinks(root string, walkFn filepath.WalkFunc) error { + info, err := os.Stat(root) + if err != nil { + return walkFn(root, nil, err) + } + return walk(root, info, walkFn, true) +} + +// WalkDirs traverses the directory, return to the relative path. +// You can specify the suffix. +func WalkDirs(targpath string, suffixes ...string) (dirlist []string) { + if !filepath.IsAbs(targpath) { + targpath, _ = filepath.Abs(targpath) + } + + err := filepath.Walk(targpath, func(retpath string, f os.FileInfo, err error) error { + if err != nil { + return err + } + + if !f.IsDir() { + return nil + } + + if len(suffixes) == 0 { + dirlist = append(dirlist, RelPath(retpath)) + return nil + } + + _retpath := RelPath(retpath) + for _, suffix := range suffixes { + if strings.HasSuffix(_retpath, suffix) { + dirlist = append(dirlist, _retpath) + } + } + + return nil + }) + + if err != nil { + log.Printf("utils.WalkRelDirs: %v\n", err) + return + } + + return +} + +// Converters, Readers & Writers ----------------------------------------------- + +// ReadFileToBytes reads data type '[]byte' from file by given path. It returns +// error when fail to finish operation. +func ReadFileToBytes(filePath string) (b []byte, err error) { + b, err = ioutil.ReadFile(filePath) + if err != nil { + return + } + return +} + +// ReadFileToString reads data type 'string' from file by given path. It returns +// error when fail to finish operation. +func ReadFileToString(filePath string, escapeNL bool) (string, error) { + b, err := ReadFileToBytes(filePath) + if err != nil { + return "", err + } + + str := string(b) + if escapeNL { + return strx.TrimRightSpace(str), nil + } + return str, nil +} + +// Helpers --------------------------------------------------------------------- + +// readDirNames reads the directory named by dirname and returns +// a channel for future results. +func readDirNames(dirname string) (<-chan dirNames, error) { + f, err := os.Open(dirname) + if err != nil { + return nil, err + } + c := make(chan dirNames) + + go func() { + defer f.Close() + defer close(c) + + for { + names, err := f.Readdirnames(1024) + if err != nil { + if err == io.EOF { + if len(names) > 0 { + c <- dirNames{names: names} + } + return + } + c <- dirNames{err: err} + return + } + c <- dirNames{names: names} + } + }() + + return c, nil +} + +// ----------------------------------------------------------------------------- diff --git a/httpx/httpx.go b/httpx/httpx.go new file mode 100644 index 0000000..1e71ad2 --- /dev/null +++ b/httpx/httpx.go @@ -0,0 +1,680 @@ +package httpx + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "net/http/httputil" + "net/textproto" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/aljiwala/gutils/ioutilx" + "github.com/aljiwala/gutils/utilsx" + "github.com/pkg/errors" +) + +// ErrNotOK is used when the status code is not 200 OK. +type ErrNotOK struct { + URL string + Err string +} + +func (e ErrNotOK) Error() string { + return fmt.Sprintf("code %v while downloading %v", e.Err, e.URL) +} + +// EnsureHTTPS wraps a HTTP handler and ensures that it was requested over HTTPS. +// If "DISABLED_ENSURE_HTTPS" is in the environment and set to either "1" or "true", +// then EnsureHTTPS should always pass. +func EnsureHTTPS(handler http.HandlerFunc) http.HandlerFunc { + v := os.Getenv("DISABLE_ENSURE_HTTPS") + disabled := v == "1" || v == "true" + return func(w http.ResponseWriter, r *http.Request) { + if !disabled && (r.URL.Scheme != "https" && r.Header.Get("X-Forwarded-Proto") != "https") { + w.WriteHeader(http.StatusForbidden) + return + } + handler(w, r) + } +} + +// CreateFormFile is like multipart.Writer.CreateFormFile, but allows the +// setting of Content-Type. +func CreateFormFile(w *multipart.Writer, fieldname, filename, contentType string) (io.Writer, error) { + eq := utilsx.EscapeQuotes + h := make(textproto.MIMEHeader) + if contentType == "" { + contentType = "application/octet-stream" + } + + h.Set("Content-Type", contentType) + h.Set("Content-Disposition", + fmt.Sprintf( + `form-data; name="%s"; filename="%s"`, + eq(fieldname), eq(filename)), + ) + + return w.CreatePart(h) +} + +// ReadRequestOneFile reads the first file from the request (if multipart/), +// or returns the body if not +func ReadRequestOneFile(r *http.Request) (body io.ReadCloser, contentType string, status int, err error) { + body = r.Body + contentType = r.Header.Get("Content-Type") + + if !strings.HasPrefix(contentType, "multipart/") { + // not multipart-form + status = http.StatusOK + return + } + defer r.Body.Close() + + err = r.ParseMultipartForm(1 << 20) + if err != nil { + status, err = http.StatusMethodNotAllowed, errors.New("error parsing request as multipart-form: "+err.Error()) + return + } + + if r.MultipartForm == nil || len(r.MultipartForm.File) == 0 { + status, err = http.StatusMethodNotAllowed, errors.New("no files?") + return + } + +Outer: + for _, fileHeaders := range r.MultipartForm.File { + for _, fileHeader := range fileHeaders { + if body, err = fileHeader.Open(); err != nil { + status, err = + http.StatusMethodNotAllowed, + fmt.Errorf( + "error opening part %q: %s", fileHeader.Filename, err, + ) + return + } + contentType = fileHeader.Header.Get("Content-Type") + break Outer + } + } + + status = http.StatusOK + return +} + +// // ReadRequestFiles reads the files from the request, and calls ReaderToFile on them +// func ReadRequestFiles(r *http.Request) (filenames []string, status int, err error) { +// defer r.Body.Close() +// err = r.ParseMultipartForm(1 << 20) +// if err != nil { +// status, err = +// http.StatusMethodNotAllowed, +// errors.New("cannot parse request as multipart-form: "+err.Error()) +// return +// } +// if r.MultipartForm == nil || len(r.MultipartForm.File) == 0 { +// status, err = http.StatusMethodNotAllowed, errors.New("no files?") +// return +// } + +// filenames = make([]string, 0, len(r.MultipartForm.File)) +// var f multipart.File +// var fn string +// for _, fileHeaders := range r.MultipartForm.File { +// for _, fh := range fileHeaders { +// if f, err = fh.Open(); err != nil { +// status, err = +// http.StatusMethodNotAllowed, +// fmt.Errorf("error reading part %q: %s", fh.Filename, err) +// return +// } + +// if fn, err = temp.ReaderToFile(f, fh.Filename, ""); err != nil { +// f.Close() +// status, err = +// http.StatusInternalServerError, +// fmt.Errorf("error saving %q: %s", fh.Filename, err) +// return +// } +// f.Close() +// filenames = append(filenames, fn) +// } +// } +// if len(filenames) == 0 { +// status, err = http.StatusMethodNotAllowed, errors.New("no files??") +// return +// } + +// status = http.StatusOK +// return +// } + +// SendFile sends the given file as response +func SendFile(w http.ResponseWriter, filename, contentType string) error { + fh, err := os.Open(filename) + if err != nil { + return err + } + defer fh.Close() + + fi, err := fh.Stat() + if err != nil { + return err + } + + size := fi.Size() + if _, err = fh.Seek(0, 0); err != nil { + err = fmt.Errorf("error seeking in %v: %s", fh, err) + http.Error(w, err.Error(), 500) + return err + } + if contentType != "" { + w.Header().Add("Content-Type", contentType) + } + + w.Header().Add("Content-Length", fmt.Sprintf("%d", size)) + w.WriteHeader(200) + fh.Seek(0, 0) + + if _, err = io.CopyN(w, fh, size); err != nil { + err = fmt.Errorf("error sending file %q: %s", filename, err) + } + + return err +} + +// Fetch the contents of an HTTP URL. +// +// This is not intended to cover all possible use cases for fetching files, +// only the most common ones. Use the net/http package for more advanced usage. +func Fetch(url string) ([]byte, error) { + client := http.Client{Timeout: 60 * time.Second} + response, err := client.Get(url) + if err != nil { + return nil, errors.Wrapf(err, "cannot download %v", url) + } + defer response.Body.Close() // nolint: errcheck + + // TODO: Maybe add sanity check to bail out of the Content-Length is very + // large? + data, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, errors.Wrapf(err, "cannot read body of %v", url) + } + + if response.StatusCode != http.StatusOK { + return data, ErrNotOK{ + URL: url, + Err: fmt.Sprintf("%v %v", response.StatusCode, response.Status), + } + } + + return data, nil +} + +// Save an HTTP URL to the directory dir with the filename. The filename can be +// generated from the URL if empty. +// +// It will return the full path to the save file. Note that it may create both a +// file *and* return an error (e.g. in cases of non-200 status codes). +// +// This is not intended to cover all possible use cases for fetching files, +// only the most common ones. Use the net/http package for more advanced usage. +func Save(url string, dir string, filename string) (string, error) { + // Use last path of url if filename is empty + if filename == "" { + tokens := strings.Split(url, "/") + filename = tokens[len(tokens)-1] + } + path := filepath.FromSlash(dir + "/" + filename) + + client := http.Client{Timeout: 60 * time.Second} + response, err := client.Get(url) + if err != nil { + return "", errors.Wrapf(err, "cannot download %v", url) + } + defer response.Body.Close() // nolint: errcheck + + output, err := os.Create(path) + if err != nil { + return "", errors.Wrapf(err, "cannot create %v", path) + } + defer output.Close() // nolint: errcheck + + _, err = io.Copy(output, response.Body) + if err != nil { + return path, errors.Wrapf(err, "cannot read body of %v in to %v", url, path) + } + + if response.StatusCode != http.StatusOK { + return path, ErrNotOK{ + URL: url, + Err: fmt.Sprintf("%v %v", response.StatusCode, response.Status), + } + } + + return path, nil +} + +// DumpBody reads the body of a HTTP request without consuming it, so it can be +// read again later. +// It will read at most maxSize of bytes. Use -1 to read everything. +// +// It's based on httputil.DumpRequest. +func DumpBody(r *http.Request, maxSize int64) ([]byte, error) { + if r.Body == nil { + return nil, nil + } + + save, body, err := ioutilx.DumpReader(r.Body) + if err != nil { + return nil, err + } + + var b bytes.Buffer + var dest io.Writer = &b + + chunked := len(r.TransferEncoding) > 0 && r.TransferEncoding[0] == "chunked" + if chunked { + dest = httputil.NewChunkedWriter(dest) + } + + if maxSize < 0 { + _, err = io.Copy(dest, body) + } else { + _, err = io.CopyN(dest, body, maxSize) + if err == io.EOF { + err = nil + } + } + if err != nil { + return nil, err + } + if chunked { + _ = dest.(io.Closer).Close() + _, _ = io.WriteString(&b, "\r\n") + } + + r.Body = save + return b.Bytes(), nil +} + +// Header Utils ---------------------------------------------------------------- + +// Octet types from RFC 2616. +var octetTypes [256]octetType + +type octetType byte + +const ( + isToken octetType = 1 << iota + isSpace +) + +func init() { + // OCTET = + // CHAR = + // CTL = + // CR = + // LF = + // SP = + // HT = + // <"> = + // CRLF = CR LF + // LWS = [CRLF] 1*( SP | HT ) + // TEXT = + // separators = "(" | ")" | "<" | ">" | "@" | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" | "{" | "}" | SP | HT + // token = 1* + // qdtext = > + + for c := 0; c < 256; c++ { + var t octetType + isCtl := c <= 31 || c == 127 + isChar := 0 <= c && c <= 127 + isSeparator := strings.ContainsRune(" \t\"(),/:;<=>?@[]\\{}", rune(c)) + if strings.ContainsRune(" \t\r\n", rune(c)) { + t |= isSpace + } + if isChar && !isCtl && !isSeparator { + t |= isToken + } + octetTypes[c] = t + } +} + +// Copy returns a shallow copy of the header. +func Copy(header http.Header) http.Header { + h := make(http.Header) + for k, vs := range header { + h[k] = vs + } + return h +} + +var timeLayouts = []string{"Mon, 02 Jan 2006 15:04:05 GMT", time.RFC850, time.ANSIC} + +// ParseTime parses the header as time. The zero value is returned if the +// header is not present or there is an error parsing the +// header. +func ParseTime(header http.Header, key string) time.Time { + if s := header.Get(key); s != "" { + for _, layout := range timeLayouts { + if t, err := time.Parse(layout, s); err == nil { + return t.UTC() + } + } + } + return time.Time{} +} + +// ParseList parses a comma separated list of values. Commas are ignored in +// quoted strings. Quoted values are not unescaped or unquoted. Whitespace is +// trimmed. +func ParseList(header http.Header, key string) []string { + var result []string + for _, s := range header[http.CanonicalHeaderKey(key)] { + begin := 0 + end := 0 + escape := false + quote := false + for i := 0; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + end = i + 1 + case quote: + switch b { + case '\\': + escape = true + case '"': + quote = false + } + end = i + 1 + case b == '"': + quote = true + end = i + 1 + case octetTypes[b]&isSpace != 0: + if begin == end { + begin = i + 1 + end = begin + } + case b == ',': + if begin < end { + result = append(result, s[begin:end]) + } + begin = i + 1 + end = begin + default: + end = i + 1 + } + } + if begin < end { + result = append(result, s[begin:end]) + } + } + return result +} + +// ParseValueAndParams parses a comma separated list of values with optional +// semicolon separated name-value pairs. Content-Type and Content-Disposition +// headers are in this format. +func ParseValueAndParams(header http.Header, key string) (value string, params map[string]string) { + params = make(map[string]string) + s := header.Get(key) + value, s = expectTokenSlash(s) + if value == "" { + return + } + value = strings.ToLower(value) + s = skipSpace(s) + for strings.HasPrefix(s, ";") { + var pkey string + pkey, s = expectToken(skipSpace(s[1:])) + if pkey == "" { + return + } + if !strings.HasPrefix(s, "=") { + return + } + var pvalue string + pvalue, s = expectTokenOrQuoted(s[1:]) + if pvalue == "" { + return + } + pkey = strings.ToLower(pkey) + params[pkey] = pvalue + s = skipSpace(s) + } + return +} + +// AcceptSpec describes an Accept* header. +type AcceptSpec struct { + Value string + Q float64 +} + +// ParseAccept parses Accept* headers. +func ParseAccept(header http.Header, key string) (specs []AcceptSpec) { +loop: + for _, s := range header[key] { + for { + var spec AcceptSpec + spec.Value, s = expectTokenSlash(s) + if spec.Value == "" { + continue loop + } + spec.Q = 1.0 + s = skipSpace(s) + if strings.HasPrefix(s, ";") { + s = skipSpace(s[1:]) + if !strings.HasPrefix(s, "q=") { + continue loop + } + spec.Q, s = expectQuality(s[2:]) + if spec.Q < 0.0 { + continue loop + } + } + specs = append(specs, spec) + s = skipSpace(s) + if !strings.HasPrefix(s, ",") { + continue loop + } + s = skipSpace(s[1:]) + } + } + return +} + +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isSpace == 0 { + break + } + } + return s[i:] +} + +func expectToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if octetTypes[s[i]]&isToken == 0 { + break + } + } + return s[:i], s[i:] +} + +func expectTokenSlash(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + b := s[i] + if (octetTypes[b]&isToken == 0) && b != '/' { + break + } + } + return s[:i], s[i:] +} + +func expectQuality(s string) (q float64, rest string) { + switch { + case len(s) == 0: + return -1, "" + case s[0] == '0': + q = 0 + case s[0] == '1': + q = 1 + default: + return -1, "" + } + s = s[1:] + if !strings.HasPrefix(s, ".") { + return q, s + } + s = s[1:] + i := 0 + n := 0 + d := 1 + for ; i < len(s); i++ { + b := s[i] + if b < '0' || b > '9' { + break + } + n = n*10 + int(b) - '0' + d *= 10 + } + return q + float64(n)/float64(d), s[i:] +} + +func expectTokenOrQuoted(s string) (value, rest string) { + pkey, s := expectToken(s) + if pkey == "" { + return + } + if !strings.HasPrefix(s, "\"") { + return "", s + } + + s = s[1:] + for i := 0; i < len(s); i++ { + switch s[i] { + case '"': + return s[:i], s[i+1:] + case '\\': + p := make([]byte, len(s)-1) + j := copy(p, s[:i]) + escape := true + for i = i + 1; i < len(s); i++ { + b := s[i] + switch { + case escape: + escape = false + p[j] = b + j++ + case b == '\\': + escape = true + case b == '"': + return string(p[:j]), s[i+1:] + default: + p[j] = b + j++ + } + } + return "", "" + } + } + return "", "" +} + +// Set Utils ------------------------------------------------------------------- + +// Constants for DispositionArgs. +const ( + TypeInline = "inline" + TypeAttachment = "attachment" +) + +// DispositionArgs are arguments for SetContentDisposition(). +type DispositionArgs struct { + Type string // disposition-type + Filename string // filename-parm + //CreationDate time.Time // creation-date-parm + //ModificationDate time.Time // modification-date-parm + //ReadDate time.Time // read-date-parm + //Size int // size-parm +} + +// SetContentDisposition sets the Content-Disposition header. Any previous value +// will be overwritten. +// +// https://tools.ietf.org/html/rfc2183 +// https://tools.ietf.org/html/rfc6266 +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition +func SetContentDisposition(header http.Header, args DispositionArgs) error { + if header == nil { + return errors.New("header is nil map") + } + + if args.Type == "" { + return errors.New("the Type field is mandatory") + } + if args.Type != TypeInline && args.Type != TypeAttachment { + return fmt.Errorf("the Type field must be %#v or %#v", TypeInline, TypeAttachment) + } + v := args.Type + + if args.Filename != "" { + // Format filename= according to as defined in RFC822. + // We don't don't allow \, and % though. Replacing \ is a slightly lazy + // way to prevent certain injections in case of user-provided strings + // (ending the quoting and injecting their own values or even headers). + // % because some user agents interpret percent-encodings, and others do + // not (according to the RFC anyway). Finally escape " with \". + r := strings.NewReplacer("\\", "", "%", "", `"`, `\"`) + args.Filename = r.Replace(args.Filename) + + // Don't allow unicode. + ascii, hasUni := hasUnicode(args.Filename) + v += fmt.Sprintf(`; filename="%v"`, ascii) + + // Add filename* for unicode, encoded according to + // https://tools.ietf.org/html/rfc5987 + if hasUni { + v += fmt.Sprintf("; filename*=UTF-8''%v", + url.QueryEscape(args.Filename)) + } + } + + header.Set("Content-Disposition", v) + return nil +} + +func hasUnicode(s string) (string, bool) { + i := 0 + has := false + deuni := make([]rune, len(s)) + for _, c := range s { + // TODO: maybe also disallow any escape chars? + switch { + case c > 255: + has = true + default: + deuni[i] = c + i++ + } + } + + return strings.TrimRight(string(deuni), "\x00"), has +} + +// ----------------------------------------------------------------------------- diff --git a/httpx/httpx_test.go b/httpx/httpx_test.go new file mode 100644 index 0000000..edc0c95 --- /dev/null +++ b/httpx/httpx_test.go @@ -0,0 +1,43 @@ +package httpx + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestEnsureHTTPS(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + regularRequest, err := http.NewRequest("GET", "http://example.com/foo", nil) + if err != nil { + t.Fatal(err) + } + + sslRequest, err := http.NewRequest("GET", "https://example.com/foo", nil) + if err != nil { + t.Fatal(err) + } + + w := httptest.NewRecorder() + EnsureHTTPS(handler)(w, regularRequest) + if w.Code == 200 { + t.Fatalf("Expected failure since scheme was http: got=%d", w.Code) + } + + w = httptest.NewRecorder() + EnsureHTTPS(handler)(w, sslRequest) + if w.Code != 200 { + t.Fatalf("Expected success since scheme was https: got=%d", w.Code) + } + + os.Setenv("DISABLE_ENSURE_HTTPS", "1") + w = httptest.NewRecorder() + EnsureHTTPS(handler)(w, regularRequest) + if w.Code != 200 { + t.Fatalf("Expected success since we've disabled required HTTPS: got=%d", w.Code) + } +} diff --git a/ioutilx/ioutilx.go b/ioutilx/ioutilx.go new file mode 100644 index 0000000..563da6e --- /dev/null +++ b/ioutilx/ioutilx.go @@ -0,0 +1,37 @@ +package ioutilx + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" +) + +// DumpReader reads all of b to memory and then returns two equivalent +// ReadClosers which will yield the same bytes. +// +// This is useful if you want to read data from an io.Reader more than once. +// +// It returns an error if the initial reading of all bytes fails. It does not +// attempt to make the returned ReadClosers have identical error-matching +// behavior. +// +// This is based on httputil.DumpRequest: +// Ref: github.com/teamwork/ioutilx.DumpBody(). +func DumpReader(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + if b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. + return http.NoBody, http.NoBody, nil + } + + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, b, err + } + + if err = b.Close(); err != nil { + return nil, b, err + } + + return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil +} diff --git a/jsonx/constants.go b/jsonx/constants.go new file mode 100644 index 0000000..aefee28 --- /dev/null +++ b/jsonx/constants.go @@ -0,0 +1,6 @@ +package gutils + +const ( + // RegexJsonpToJSON represents regular expression to convert JSONP to JSON. + RegexJsonpToJSON = "([^\\s\\:\\{\\,\\d\"]+|[a-z][a-z\\d]*)\\s*\\:" +) diff --git a/jsonx/jsonx.go b/jsonx/jsonx.go new file mode 100644 index 0000000..492a8e8 --- /dev/null +++ b/jsonx/jsonx.go @@ -0,0 +1,36 @@ +package gutils + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" +) + +// JSONString should convert objects to JSON strings. +func JSONString(obj interface{}) (r string) { + b, _ := json.Marshal(obj) + s := fmt.Sprintf("%+v", string(b)) + r = strings.Replace(s, `\u003c`, "<", -1) + r = strings.Replace(r, `\u003e`, ">", -1) + return +} + +// JsonpToJSON should modify JSONP string to json string. +// Usecase: JsonpToJson({a:1,b:2}) -> {"a":1,"b":2} +// Ref: https://stackoverflow.com/a/3840118/4039768 (What is JSONP?) +func JsonpToJSON(s string) string { + start := strings.Index(s, "{") + end := strings.LastIndex(s, "}") + nextStart := strings.Index(s, "[") + if nextStart > 0 && start > nextStart { + start = nextStart + end = strings.LastIndex(s, "]") + } + if end > start && end != -1 && start != -1 { + s = s[start : end+1] + } + s = strings.Replace(s, "\\'", "", -1) + regexp, _ := regexp.Compile(RegexJsonpToJSON) + return regexp.ReplaceAllString(s, "\"$1\":") +} diff --git a/logx/logx.go b/logx/logx.go new file mode 100644 index 0000000..9f17502 --- /dev/null +++ b/logx/logx.go @@ -0,0 +1,190 @@ +// Package logx is simple extended version of standard 'log' package based on +// logLevel. Most of the concepts inspired from https://godoc.org/github.com/golang/glog +// and https://github.com/goinggo/tracelog. These packages are huge and complex. +// Hence we writing our own log package with as simple as possible. +// Another ref: https://github.com/UlricQin/goutils/blob/master/logtool/logtool.go +package logx + +import ( + "fmt" + "io" + "io/ioutil" + "log" + "os" +) + +// logLevel is a severity level at which logger works. +type logLevel int + +const ( + // InfoLevel ... + InfoLevel logLevel = iota + // WarningLevel ... + WarningLevel + // ErrorLevel ... + ErrorLevel + // FatalLevel ... + FatalLevel +) + +var ( + infoLogger *log.Logger + warningLogger *log.Logger + errorLogger *log.Logger + fatalLogger *log.Logger + + // Levels ... + Levels = map[string]logLevel{ + "INFO": InfoLevel, + "WARNING": WarningLevel, + "ERROR": ErrorLevel, + "FATAL": FatalLevel, + } +) + +func init() { + Init(InfoLevel, nil) +} + +// Init ... +func Init(lev logLevel, multiHandler io.Writer) { + infoHandler := ioutil.Discard + warningHandler := ioutil.Discard + errorHandler := ioutil.Discard + fatalHandler := ioutil.Discard + + switch lev { + case InfoLevel: + infoHandler = os.Stdout + warningHandler = os.Stdout + errorHandler = os.Stderr + fatalHandler = os.Stderr + case WarningLevel: + warningHandler = os.Stdout + errorHandler = os.Stderr + fatalHandler = os.Stderr + case ErrorLevel: + errorHandler = os.Stderr + fatalHandler = os.Stderr + case FatalLevel: + fatalHandler = os.Stderr + default: + log.Fatal("logx: Invalid log level should be (0-3)") + } + + if multiHandler != nil { + if infoHandler == os.Stdout { + infoHandler = io.MultiWriter(infoHandler, multiHandler) + } + if warningHandler == os.Stdout { + warningHandler = io.MultiWriter(warningHandler, multiHandler) + } + if errorHandler == os.Stderr { + errorHandler = io.MultiWriter(errorHandler, multiHandler) + } + if fatalHandler == os.Stderr { + fatalHandler = io.MultiWriter(fatalHandler, multiHandler) + } + } + + infoLogger = log.New(infoHandler, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) + warningLogger = log.New(warningHandler, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile) + errorLogger = log.New(errorHandler, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) + fatalLogger = log.New(fatalHandler, "FATAL: ", log.Ldate|log.Ltime|log.Lshortfile) +} + +// ----------------------------------------------------------------------------- +// Wrapper for changing logger's output writer anytime. + +// InfoSetOutput ... +func InfoSetOutput(w io.Writer) { + infoLogger.SetOutput(w) +} + +// WarningSetOutput ... +func WarningSetOutput(w io.Writer) { + warningLogger.SetOutput(w) +} + +// ErrorSetOutput ... +func ErrorSetOutput(w io.Writer) { + errorLogger.SetOutput(w) +} + +// FatalSetOutput ... +func FatalSetOutput(w io.Writer) { + fatalLogger.SetOutput(w) +} + +// ----------------------------------------------------------------------------- + +// Info writes into infologger as same as basic fmt.Print. +func Info(v ...interface{}) { + infoLogger.Output(2, fmt.Sprint(v...)) +} + +// Infoln writes into infologger as same as basic fmt.Println. +func Infoln(v ...interface{}) { + infoLogger.Output(2, fmt.Sprintln(v...)) +} + +// Infof writes into infologger as same as basic Outputf. +func Infof(format string, v ...interface{}) { + infoLogger.Output(2, fmt.Sprintf(format, v...)) +} + +// Warning writes warning messages into warninglogger as same as basic +// log.Output. +func Warning(v ...interface{}) { + warningLogger.Output(2, fmt.Sprint(v...)) +} + +// Warningln writes warning messages into warninglogger as same as basic +// log.Outputln. +func Warningln(v ...interface{}) { + warningLogger.Output(2, fmt.Sprintln(v...)) +} + +// Warningf writes warning messages into warninglogger as same as basic +// log.Outputf. +func Warningf(format string, v ...interface{}) { + warningLogger.Output(2, fmt.Sprintf(format, v...)) +} + +// Error writes error messages into errorlogger as same as basic log.Error. +func Error(v ...interface{}) { + errorLogger.Output(2, fmt.Sprint(v...)) +} + +// Errorln writes error messages into errorlogger as same as basic log.Errorln. +func Errorln(v ...interface{}) { + errorLogger.Output(2, fmt.Sprintln(v...)) +} + +// Errorf writes error messages into errorlogger as same as basic log.Errorf. +func Errorf(format string, v ...interface{}) { + errorLogger.Output(2, fmt.Sprintf(format, v...)) +} + +// Fatal writes fatal error messages into errorlogger and exit as same as basic +// log.Fatal. +func Fatal(v ...interface{}) { + fatalLogger.Output(2, fmt.Sprint(v...)) + os.Exit(1) +} + +// Fatalln writes fatal error messages into errorlogger and exit as same as +// basic log.Fataln. +func Fatalln(v ...interface{}) { + fatalLogger.Output(2, fmt.Sprintln(v...)) + os.Exit(1) +} + +// Fatalf writes fatal error messages into errorlogger and exit as same as basic +// log.Fataf. +func Fatalf(format string, v ...interface{}) { + fatalLogger.Output(2, fmt.Sprintf(format, v...)) + os.Exit(1) +} + +// ----------------------------------------------------------------------------- diff --git a/logx/logx_test.go b/logx/logx_test.go new file mode 100644 index 0000000..e59cfd7 --- /dev/null +++ b/logx/logx_test.go @@ -0,0 +1,25 @@ +package logx + +import "testing" + +func TestInitWithoutMultiHandle(t *testing.T) { + // TODO: write test for logLevel check. If FatalLevel is set other handlers should have iouti.Discard writer + // currently noway to get the writer for any logger via standard log package. + + Init(InfoLevel, nil) + if infoLogger.Prefix() != "INFO: " { + t.Error("Error in infologger initialize") + } + + if warningLogger.Prefix() != "WARNING: " { + t.Error("Error in warninglogger initialize") + } + + if errorLogger.Prefix() != "ERROR: " { + t.Error("Error in errorlogger initialize") + } + + if fatalLogger.Prefix() != "FATAL: " { + t.Error("Error in errorlogger initialize") + } +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..2cb499b --- /dev/null +++ b/main.go @@ -0,0 +1,10 @@ +// Package main contains common and useful utils for the Go project development. +// +// Inclusion criteria: +// - Only rely on the Go standard package +// - Functions or lightweight packages +// - Non-business related general tools +package main + +func main() { +} diff --git a/mapx/mapx.go b/mapx/mapx.go new file mode 100644 index 0000000..f1de93d --- /dev/null +++ b/mapx/mapx.go @@ -0,0 +1,10 @@ +package mapx + +// Reverse the keys and values of a map. +func Reverse(m map[string]string) map[string]string { + n := make(map[string]string) + for k, v := range m { + n[v] = k + } + return n +} diff --git a/mathx/mathx.go b/mathx/mathx.go new file mode 100644 index 0000000..dbfefce --- /dev/null +++ b/mathx/mathx.go @@ -0,0 +1,125 @@ +package mathx + +import ( + "fmt" + "math" +) + +// Byte is a float64 where the String() method prints out a human-redable description. +type Byte float64 + +var units = []string{"B", "KiB", "MiB", "GiB", "TiB", "PiB"} + +// Round will round the value to the nearest natural number. +// .5 will be rounded up. +func Round(f float64) float64 { + if f < 0 { + return math.Ceil(f - 0.5) + } + return math.Floor(f + 0.5) +} + +// RoundPlus will round the value to the given precision. +// e.g. RoundPlus(7.258, 2) will return 7.26 +func RoundPlus(f float64, precision int) float64 { + shift := math.Pow(10, float64(precision)) + return Round(f*shift) / shift +} + +// Min gets the lowest of two numbers. +func Min(a, b int64) int64 { + if a > b { + return b + } + return a +} + +// Max gets the highest of two numbers. +func Max(a, b int64) int64 { + if a < b { + return b + } + return a +} + +// MinMax should return min and max int value from given container. +func MinMax(container []int) (min int, max int) { + if len(container) == 0 { + return + } + + min = container[0] + max = container[0] + for _, value := range container { + if max < value { + max = value + } + if min > value { + min = value + } + } + + return +} + +// NextMin should return min int value from given container. +func NextMin(v []int, cur int) (int, bool) { + minElems := []int{} + for _, e := range v { + if e < cur { + minElems = append(minElems, e) + } + } + + if len(minElems) == 0 { + return cur, true + } + + _, max := MinMax(minElems) + return max, false +} + +// NextMax should return max int value from given container. +func NextMax(v []int, cur int) (int, bool) { + maxElems := []int{} + for _, e := range v { + if e > cur { + maxElems = append(maxElems, e) + } + } + + if len(maxElems) == 0 { + return cur, true + } + + min, _ := MinMax(maxElems) + return min, false +} + +// Limit a value between a lower and upper limit. +func Limit(v, lower, upper float64) float64 { + return math.Max(math.Min(v, upper), lower) +} + +// DivideCeil divides two integers and rounds up, rather than down (which is +// what happens when you do int64/int64). +func DivideCeil(count int64, pageSize int64) int64 { + return int64(math.Ceil(float64(count) / float64(pageSize))) +} + +// IsSignedZero checks if this number is a signed zero (i.e. -0, instead of +0). +func IsSignedZero(f float64) bool { + return math.Float64bits(f)^uint64(1<<63) == 0 +} + +// String is the string representation of byte. +func (b Byte) String() string { + i := 0 + for ; i < len(units); i++ { + if b < 1024 { + return fmt.Sprintf("%.1f%s", b, units[i]) + } + b /= 1024 + } + return fmt.Sprintf("%.1f%s", b*1024, units[i-1]) +} diff --git a/netx/netx.go b/netx/netx.go new file mode 100644 index 0000000..d61f776 --- /dev/null +++ b/netx/netx.go @@ -0,0 +1,466 @@ +package netx + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "os/exec" + "strconv" + "strings" + + "github.com/pkg/errors" + "golang.org/x/net/context/ctxhttp" +) + +const gmapsURL = `https://maps.googleapis.com/maps/api/geocode/json?sensors=false&address={{.Address}}` + +var ( + // ErrNotFound ... + ErrNotFound = errors.New("not found") + // ErrTooManyResults ... + ErrTooManyResults = errors.New("too many results") +) + +type ( + // Coordinates ... + Coordinates struct { + Latitude float64 `json:"latitude"` + Longitude float64 `json:"longitude"` + } + + // Location ... + Location struct { + Address string + Coordinates + } + + // MapsLocation ... + MapsLocation struct { + Coordinates + } + + // MapsGeometry ... + MapsGeometry struct { + Location MapsLocation `json:"location"` + } + + // MapsResult ... + MapsResult struct { + FormattedAddress string `json:"formatted_address"` + Geometry MapsGeometry `json:"geometry"` + } + + // MapsResponse ... + MapsResponse struct { + Status string `json:"status"` + Results []MapsResult `json:"results"` + } + + // Net ... + Net struct { + Address string + Bitmask uint8 + Mask string + Hostmask string + Broadcast string + First string + Last string + Size uint32 + } +) + +// Get ... +func Get(ctx context.Context, address string) (Location, error) { + var ( + loc Location + data MapsResponse + ) + + select { + case <-ctx.Done(): + return loc, ctx.Err() + default: + } + + aURL := strings.Replace(gmapsURL, "{{.Address}}", url.QueryEscape(address), 1) + resp, err := ctxhttp.Get(ctx, nil, aURL) + if err != nil { + return loc, errors.Wrapf(err, aURL) + } + defer resp.Body.Close() + if resp.StatusCode > 299 { + return loc, errors.Wrapf(err, aURL) + } + + if err = json.NewDecoder(resp.Body).Decode(&data); err != nil { + return loc, errors.Wrapf(err, "decode") + } + + switch data.Status { + case "OK": + case "ZERO_RESULTS": + return loc, ErrNotFound + default: + return loc, errors.Wrapf(err, "status=%q", data.Status) + } + + switch len(data.Results) { + case 0: + return loc, ErrNotFound + case 1: + default: + return loc, ErrTooManyResults + } + + result := data.Results[0] + loc.Address = result.FormattedAddress + loc.Coordinates.Latitude, loc.Coordinates.Longitude = + result.Geometry.Location.Coordinates.Latitude, + result.Geometry.Location.Coordinates.Longitude + + return loc, nil +} + +// RemovePort removes the "port" part of an hostname. +func RemovePort(host string) string { + shost, _, err := net.SplitHostPort(host) + // Probably doesn't have a port, which is an error. + if err != nil { + return host + } + return shost +} + +// LocalIP should return local IP address. +func LocalIP() (string, error) { + addr, err := net.ResolveUDPAddr("udp", "1.2.3.4:1") + if err != nil { + return "", err + } + + conn, err := net.DialUDP("udp", nil, addr) + if err != nil { + return "", err + } + + defer conn.Close() + + host, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return "", err + } + + // host = "10.180.2.66" + return host, nil +} + +// LocalDNSName should return host name. +func LocalDNSName() (hostname string, err error) { + var ip string + ip, err = LocalIP() + if err != nil { + return + } + + cmd := exec.Command("host", ip) + var out bytes.Buffer + cmd.Stdout = &out + err = cmd.Run() + if err != nil { + return + } + + tmp := out.String() + arr := strings.Split(tmp, ".\n") + + if len(arr) > 1 { + content := arr[0] + arr = strings.Split(content, " ") + return arr[len(arr)-1], nil + } + + err = fmt.Errorf("parse host %s fail", ip) + return +} + +// IntranetIP get internal IP addr. +func IntranetIP() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", err + } + + for _, iface := range ifaces { + if iface.Flags&net.FlagUp == 0 { + continue // Interface down. + } + + if iface.Flags&net.FlagLoopback != 0 { + continue // Loopback interface. + } + + if strings.HasPrefix(iface.Name, "docker") || + strings.HasPrefix(iface.Name, "w-") { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + return "", err + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + + if ip == nil || ip.IsLoopback() { + continue + } + + ip = ip.To4() + if ip == nil { + continue // Not an IPv4 address. + } + + return ip.String(), nil + } + } + + return "", errors.New("Are you connected to the network?") +} + +// ExtranetIP should return external IP address. +func ExtranetIP() (ip string, err error) { + defer func() { + if p := recover(); p != nil { + err = fmt.Errorf("Get external IP error: %v", p) + } else if err != nil { + err = errors.New("Get external IP error: " + err.Error()) + } + }() + + resp, err := http.Get("http://pv.sohu.com/cityjson?ie=utf-8") + if err != nil { + return + } + + b, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + idx := bytes.Index(b, []byte(`"cip": "`)) + b = b[idx+len(`"cip": "`):] + idx = bytes.Index(b, []byte(`"`)) + b = b[:idx] + ip = string(b) + + return +} + +// Atoi returns the uint32 representation of an ipv4 addr string value. +// +// Example: +// +// Atoi("192.168.0.1") // 3232235521 +// +func Atoi(addr string) (sum uint32, err error) { + if len(addr) > 15 { + return sum, errors.New("addr too long") + } + + octs := strings.Split(addr, ".") + if len(octs) != 4 { + return sum, errors.New("requires 4 octects") + } + + for i := 0; i < 4; i++ { + oct, err := strconv.ParseUint(octs[i], 10, 0) + if err != nil { + return sum, errors.New("bad octect " + octs[i]) + } + sum += uint32(oct) << uint32((4-1-i)*8) + } + return sum, nil +} + +// Itoa returns the string representation of an ipv4 addr uint32 value. +// +// Example: +// +// Itoa(3232235521) // "192.168.0.1" +// +func Itoa(addr uint32) string { + var buf bytes.Buffer + + for i := 0; i < 4; i++ { + oct := (addr >> uint32((4-1-i)*8)) & 0xff + buf.WriteString(strconv.FormatUint(uint64(oct), 10)) + if i < 3 { + buf.WriteByte('.') + } + } + return buf.String() +} + +// Not ... +// Example: +// +// Not("0.0.255.255") // "255.255.0.0" +// +func Not(addr string) (string, error) { + i, err := Atoi(addr) + return Itoa(i ^ 0xffffffff), err +} + +// Or ... +// Example: +// +// Or("0.0.1.1", "1.1.0.0") // "1.1.1.1" +// +func Or(addra string, addrb string) (addr string, err error) { + ia, err := Atoi(addra) + if err != nil { + return addr, err + } + + ib, err := Atoi(addrb) + if err != nil { + return addr, err + } + + return Itoa(ia | ib), err +} + +// Xor ... +// Example: +// +// Xor("0.255.255.255", "192.255.255.255") // "192.0.0.0" +// +func Xor(addra string, addrb string) (addr string, err error) { + ia, err := Atoi(addra) + if err != nil { + return addr, err + } + + ib, err := Atoi(addrb) + if err != nil { + return addr, err + } + + return Itoa(ia ^ ib), err +} + +// Next ... +// Example: +// +// Next("192.168.0.1") // "192.168.0.2" +// +func Next(addr string) (string, error) { + i, err := Atoi(addr) + return Itoa(i + 1), err +} + +// Prev ... +// Example: +// +// Prev("192.168.0.1") // "192.168.0.0" +// +func Prev(addr string) (string, error) { + i, err := Atoi(addr) + return Itoa(i - 1), err +} + +// Network returns information for a netblock. +// +// Example: +// +// Network("192.168.0.0/24") +// // { +// // Address: "192.168.0.0", +// // Bitmask: 24, +// // Mask: "255.255.255.0", +// // Hostmask: "0.0.0.255", +// // Broadcast: "192.168.0.255", +// // First: "192.168.0.1", +// // Last: "192.168.0.254", +// // Size: 254, +// // } +func Network(block string) (net Net, err error) { + if len(block) > 18 { + return net, errors.New("block too long") + } + + list := strings.Split(block, "/") + if len(list) != 2 { + return net, errors.New("invalid block") + } + + // address + net.Address = list[0] + + // bitmask + bitmask, err := strconv.ParseUint(list[1], 10, 0) + if err != nil { + return net, err + } + if bitmask&31 != bitmask { + return net, errors.New("invalid bitmask") + } + net.Bitmask = uint8(bitmask) + + // mask + net.Mask = Itoa(0xffffffff >> (32 - net.Bitmask) << (32 - net.Bitmask)) + net.Hostmask, err = Not(net.Mask) + if err != nil { + return net, err + } + + // broadcast + net.Broadcast, err = Or(net.Address, net.Hostmask) + if err != nil { + return net, err + } + + // first + addr, err := Xor(net.Hostmask, net.Broadcast) + if err != nil { + return net, err + } + + net.First, err = Next(addr) + if err != nil { + return net, err + } + + // last + net.Last, err = Prev(net.Broadcast) + if err != nil { + return net, err + } + + // size + i, err := Atoi(net.Last) + if err != nil { + return net, err + } + + j, err := Atoi(net.First) + if err != nil { + return net, err + } + + net.Size = i - j + 1 + return net, nil +} diff --git a/netx/netx_test.go b/netx/netx_test.go new file mode 100644 index 0000000..8276964 --- /dev/null +++ b/netx/netx_test.go @@ -0,0 +1,206 @@ +package netx + +import ( + "testing" +) + +var RemovePortCases = []struct { + in, expected string +}{ + {"127.0.0.1:2345", "127.0.0.1"}, + {"127.0.0.1", "127.0.0.1"}, + {"127.0.0.1:", "127.0.0.1"}, + {"::1", "::1"}, + {"[::1]:80", "::1"}, + {"arp242.net:", "arp242.net"}, + {"arp242.net:", "arp242.net"}, + {"arp242.net:8080", "arp242.net"}, +} + +func TestRemovePort(t *testing.T) { + for _, tc := range RemovePortCases { + t.Run(tc.in, func(t *testing.T) { + out := RemovePort(tc.in) + if out != tc.expected { + t.Errorf("\nout: %#v\nexpected: %#v\n", out, tc.expected) + } + }) + } +} + +var AtoiCases = []struct { + addr string + except uint32 +}{ + {"0.0.0.0", 0}, + {"0.0.1.0", 256}, + {"0.1.1.0", 256 + 256*256}, + {"1.1.1.0", 256 + 256*256 + 256*256*256}, + {"192.168.0.1", 3232235521}, +} + +func TestAtoi(t *testing.T) { + for _, testCase := range AtoiCases { + got, err := Atoi(testCase.addr) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var ItoaCases = []struct { + addr uint32 + except string +}{ + {0, "0.0.0.0"}, + {256, "0.0.1.0"}, + {256 + 256*256, "0.1.1.0"}, + {256 + 256*256 + 256*256*256, "1.1.1.0"}, + {3232235521, "192.168.0.1"}, +} + +func TestItoa(t *testing.T) { + for _, testCase := range ItoaCases { + got := Itoa(testCase.addr) + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var NotCases = []struct { + addr string + except string +}{ + {"0.0.1.1", "255.255.254.254"}, + {"0.0.255.255", "255.255.0.0"}, +} + +func TestNot(t *testing.T) { + for _, testCase := range NotCases { + got, err := Not(testCase.addr) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var OrCases = []struct { + addra string + addrb string + except string +}{ + {"0.0.1.1", "1.1.0.0", "1.1.1.1"}, + {"0.0.1.2", "1.2.0.0", "1.2.1.2"}, + {"0.0.1.233", "1.2.0.0", "1.2.1.233"}, + {"0.0.1.233", "1.2.0.2", "1.2.1.235"}, +} + +func TestOr(t *testing.T) { + for _, testCase := range OrCases { + got, err := Or(testCase.addra, testCase.addrb) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var XorCases = []struct { + addra string + addrb string + except string +}{ + {"0.255.255.255", "192.255.255.255", "192.0.0.0"}, +} + +func TestXor(t *testing.T) { + for _, testCase := range XorCases { + got, err := Xor(testCase.addra, testCase.addrb) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var PrevCases = []struct { + addr string + except string +}{ + {"0.0.0.1", "0.0.0.0"}, + {"0.0.255.255", "0.0.255.254"}, + {"0.0.0.0", "255.255.255.255"}, + {"192.168.0.1", "192.168.0.0"}, +} + +func TestPrev(t *testing.T) { + for _, testCase := range PrevCases { + got, err := Prev(testCase.addr) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var NextCases = []struct { + addr string + except string +}{ + {"0.0.0.0", "0.0.0.1"}, + {"0.0.255.255", "0.1.0.0"}, + {"255.255.255.255", "0.0.0.0"}, + {"192.168.0.1", "192.168.0.2"}, +} + +func TestNext(t *testing.T) { + for _, testCase := range NextCases { + got, err := Next(testCase.addr) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} + +var NetworkCases = []struct { + block string + except Net +}{ + {"192.168.0.0/24", Net{ + "192.168.0.0", + 24, + "255.255.255.0", + "0.0.0.255", + "192.168.0.255", + "192.168.0.1", + "192.168.0.254", + 254}}, +} + +func TestNetwork(t *testing.T) { + for _, testCase := range NetworkCases { + got, err := Network(testCase.block) + if err != nil { + t.Errorf("unexcepted error %v", err) + } + if got != testCase.except { + t.Errorf("except=%v but got=%v", testCase.except, got) + } + } +} diff --git a/netx/rpcx/rpcx.go b/netx/rpcx/rpcx.go new file mode 100644 index 0000000..4abf2f2 --- /dev/null +++ b/netx/rpcx/rpcx.go @@ -0,0 +1,18 @@ +package rpcx + +import ( + "net" + "net/rpc" + "net/rpc/jsonrpc" + "time" +) + +// DialTimeout acts like Dial but takes a timeout. Returns a new rpc.Client to +// handle requests to the set of services at the other end of the connection. +func DialTimeout(network, address string, timeout time.Duration) (*rpc.Client, error) { + conn, err := net.DialTimeout(network, address, timeout) + if err != nil { + return nil, err + } + return jsonrpc.NewClient(conn), err +} diff --git a/osx/osx.go b/osx/osx.go new file mode 100644 index 0000000..f4e4169 --- /dev/null +++ b/osx/osx.go @@ -0,0 +1,165 @@ +package osx + +import ( + "bufio" + "bytes" + "errors" + "io/ioutil" + "log" + "os" + "os/exec" + "strconv" + "strings" + "sync" + "time" +) + +var ( + lastmod, lastcheck time.Time + groups map[int]string + groupsMu sync.RWMutex +) + +const groupFile = "/etc/group" + +// IntTimeout is the duration to wait before Kill after Int +var IntTimeout = 3 * time.Second + +// Log is discarded by default +var Log = func(keyvals ...interface{}) error { return nil } + +// ErrTimedOut is an error for child timeout +var ErrTimedOut = errors.New("child timed out") + +type gCmd struct { + *exec.Cmd + done chan error +} + +func (c *gCmd) Start() error { + if err := c.Cmd.Start(); err != nil { + return err + } + c.done = make(chan error, 1) + go func() { c.done <- c.Cmd.Wait() }() + return nil +} + +// RunWithTimeout runs cmd, and kills the child on timeout +func RunWithTimeout(timeoutSeconds int, cmd *exec.Cmd) error { + if cmd.SysProcAttr == nil { + procAttrSetGroup(cmd) + } + + gcmd := &gCmd{Cmd: cmd} + if err := gcmd.Start(); err != nil { + return err + } + if timeoutSeconds <= 0 { + return <-gcmd.done + } + + select { + case err := <-gcmd.done: + return err + case <-time.After(time.Second * time.Duration(timeoutSeconds)): + Log("msg", "killing timed out", "pid", cmd.Process.Pid, "path", cmd.Path, "args", cmd.Args) + if killErr := familyKill(gcmd.Cmd, true); killErr != nil { + Log("msg", "interrupt", "pid", cmd.Process.Pid) + } + + select { + case <-gcmd.done: + case <-time.After(IntTimeout): + familyKill(gcmd.Cmd, false) + } + } + + return ErrTimedOut +} + +// GroupName returns the name for the gid. +func GroupName(gid int) (string, error) { + groupsMu.RLock() + if groups == nil { + groupsMu.RUnlock() + groupsMu.Lock() + defer groupsMu.Unlock() + if groups != nil { // sy was faster + name := groups[gid] + return name, nil + } + } else { + now := time.Now() + if lastcheck.Add(1 * time.Second).After(now) { // fresh + name := groups[gid] + groupsMu.RUnlock() + return name, nil + } + + actcheck := lastcheck + groupsMu.RUnlock() + groupsMu.Lock() + defer groupsMu.Unlock() + if lastcheck != actcheck { // sy was faster + return groups[gid], nil + } + + fi, err := os.Stat(groupFile) + if err != nil { + return "", err + } + + lastcheck = now + if lastmod == fi.ModTime() { // no change + return groups[gid], nil + } + } + + // need to reread + if groups == nil { + groups = make(map[int]string, 64) + } else { + for k := range groups { + delete(groups, k) + } + } + + fh, err := os.Open(groupFile) + if err != nil { + return "", err + } + defer fh.Close() + + fi, err := fh.Stat() + if err != nil { + return "", err + } + + lastcheck = time.Now() + lastmod = fi.ModTime() + scanner := bufio.NewScanner(fh) + for scanner.Scan() { + parts := strings.SplitN(scanner.Text(), ":", 4) + id, err := strconv.Atoi(parts[2]) + if err != nil { + log.Printf("cannot parse %q as group id from line %q: %v", parts[2], scanner.Text(), err) + } + if old, ok := groups[id]; ok { + log.Printf("double entry %d: %q and %q?", id, old, parts[0]) + continue + } + groups[id] = parts[0] + } + + return groups[gid], nil +} + +// IsInsideDocker returns true iff we are inside a docker cgroup. +func IsInsideDocker() bool { + b, err := ioutil.ReadFile("/proc/self/cgroup") + if err != nil { + return false + } + return bytes.Contains(b, []byte(":/docker/")) || bytes.Contains(b, []byte(":/lxc/")) +} diff --git a/osx/proc_linux.go b/osx/proc_linux.go new file mode 100644 index 0000000..0b11562 --- /dev/null +++ b/osx/proc_linux.go @@ -0,0 +1,13 @@ +package osx + +import ( + "os/exec" + "syscall" +) + +func procAttrSetGroup(c *exec.Cmd) { + c.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, // to be able to kill all children, too + Pdeathsig: syscall.SIGKILL, + } +} diff --git a/osx/proc_posix.go b/osx/proc_posix.go new file mode 100644 index 0000000..4e0c763 --- /dev/null +++ b/osx/proc_posix.go @@ -0,0 +1,33 @@ +package osx + +import ( + "os" + "os/exec" + "strconv" + "syscall" +) + +func isGroupLeader(c *exec.Cmd) bool { + return c.SysProcAttr != nil && c.SysProcAttr.Setpgid +} + +// Pkill kills the process with the given pid, or just -INT if interrupt is true. +func Pkill(pid int, signal os.Signal) error { + signum := signal.(syscall.Signal) + + var err error + defer func() { + if r := recover(); r == nil && err == nil { + return + } + err = exec.Command("pkill", "-"+strconv.Itoa(int(signum)), + "-P", strconv.Itoa(pid)).Run() + }() + err = syscall.Kill(pid, signum) + return err +} + +// GroupKill kills the process group lead by the given pid +func GroupKill(pid int, signal os.Signal) error { + return Pkill(-pid, signal) +} diff --git a/osx/procx.go b/osx/procx.go new file mode 100644 index 0000000..77d8888 --- /dev/null +++ b/osx/procx.go @@ -0,0 +1,56 @@ +package osx + +import ( + "fmt" + "os" + "os/exec" +) + +// KillWithChildren kills the process +// and tries to kill its all children (process group) +func KillWithChildren(p *os.Process, interrupt bool) (err error) { + if p == nil { + return + } + fmt.Println("msg", "killWithChildren", "pid", p.Pid, "interrupt", interrupt) + defer func() { + if r := recover(); r != nil { + Log("msg", "PANIC in kill", "process", p, "error", r) + } + }() + defer p.Release() + + if p.Pid == 0 { + return nil + } + if interrupt { + defer p.Signal(os.Interrupt) + return Pkill(p.Pid, os.Interrupt) + } + defer p.Kill() + + return Pkill(p.Pid, os.Kill) +} + +func groupKill(p *os.Process, interrupt bool) error { + if p == nil { + return nil + } + fmt.Println("msg", "groupKill", "pid", p.Pid) + defer recover() + + if interrupt { + defer p.Signal(os.Interrupt) + return GroupKill(p.Pid, os.Interrupt) + } + defer p.Kill() + + return GroupKill(p.Pid, os.Kill) +} + +func familyKill(cmd *exec.Cmd, interrupt bool) error { + if cmd.SysProcAttr != nil && isGroupLeader(cmd) { + return groupKill(cmd.Process, interrupt) + } + return KillWithChildren(cmd.Process, interrupt) +} diff --git a/paginator/paginator.go b/paginator/paginator.go new file mode 100644 index 0000000..a4efd65 --- /dev/null +++ b/paginator/paginator.go @@ -0,0 +1,169 @@ +package paginator + +import ( + "math" + "net/http" + "net/url" + "strconv" + + "github.com/aljiwala/gutils/converter" +) + +// Paginator ... +type Paginator struct { + Request *http.Request + PerPageNums int + MaxPages int + + nums int64 + pageRange []int + pageNums int + page int +} + +// PageNums ... +func (p *Paginator) PageNums() int { + if p.pageNums != 0 { + return p.pageNums + } + pageNums := math.Ceil(float64(p.nums) / float64(p.PerPageNums)) + if p.MaxPages > 0 { + pageNums = math.Min(pageNums, float64(p.MaxPages)) + } + p.pageNums = int(pageNums) + return p.pageNums +} + +// Nums ... +func (p *Paginator) Nums() int64 { + return p.nums +} + +// SetNums ... +func (p *Paginator) SetNums(nums interface{}) { + p.nums, _ = converter.ToInt64(nums) +} + +// Page ... +func (p *Paginator) Page() int { + if p.page != 0 { + return p.page + } + if p.Request.Form == nil { + p.Request.ParseForm() + } + p.page, _ = strconv.Atoi(p.Request.Form.Get("p")) + if p.page > p.PageNums() { + p.page = p.PageNums() + } + if p.page <= 0 { + p.page = 1 + } + return p.page +} + +// Pages ... +func (p *Paginator) Pages() []int { + if p.pageRange == nil && p.nums > 0 { + var pages []int + pageNums := p.PageNums() + page := p.Page() + switch { + case page >= pageNums-4 && pageNums > 9: + start := pageNums - 9 + 1 + pages = make([]int, 9) + for i := range pages { + pages[i] = start + i + } + case page >= 5 && pageNums > 9: + start := page - 5 + 1 + pages = make([]int, int(math.Min(9, float64(page+4+1)))) + for i := range pages { + pages[i] = start + i + } + default: + pages = make([]int, int(math.Min(9, float64(pageNums)))) + for i := range pages { + pages[i] = i + 1 + } + } + p.pageRange = pages + } + return p.pageRange +} + +// PageLink ... +func (p *Paginator) PageLink(page int) string { + link, _ := url.ParseRequestURI(p.Request.RequestURI) + values := link.Query() + if page == 1 { + values.Del("p") + } else { + values.Set("p", strconv.Itoa(page)) + } + link.RawQuery = values.Encode() + return link.String() +} + +// PageLinkPrev ... +func (p *Paginator) PageLinkPrev() (link string) { + if p.HasPrev() { + link = p.PageLink(p.Page() - 1) + } + return +} + +// PageLinkNext ... +func (p *Paginator) PageLinkNext() (link string) { + if p.HasNext() { + link = p.PageLink(p.Page() + 1) + } + return +} + +// PageLinkFirst ... +func (p *Paginator) PageLinkFirst() (link string) { + return p.PageLink(1) +} + +// PageLinkLast ... +func (p *Paginator) PageLinkLast() (link string) { + return p.PageLink(p.PageNums()) +} + +// HasPrev ... +func (p *Paginator) HasPrev() bool { + return p.Page() > 1 +} + +// HasNext ... +func (p *Paginator) HasNext() bool { + return p.Page() < p.PageNums() +} + +// IsActive ... +func (p *Paginator) IsActive(page int) bool { + return p.Page() == page +} + +// Offset ... +func (p *Paginator) Offset() int { + return (p.Page() - 1) * p.PerPageNums +} + +// HasPages ... +func (p *Paginator) HasPages() bool { + return p.PageNums() > 1 +} + +// NewPaginator ... +func NewPaginator(req *http.Request, per int, nums interface{}) *Paginator { + p := Paginator{} + p.Request = req + if per <= 0 { + per = 10 + } + p.PerPageNums = per + p.SetNums(nums) + return &p +} diff --git a/randx/randx.go b/randx/randx.go new file mode 100644 index 0000000..5108127 --- /dev/null +++ b/randx/randx.go @@ -0,0 +1,113 @@ +package randx + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + mrand "math/rand" + "os" + + "github.com/aljiwala/gutils/convertx" + "github.com/aljiwala/gutils/strx" +) + +// NewRandom creates a new padded Encoding defined by the given alphabet. +func NewRandom(encoderSeed string, ignore ...byte) *Random { + r := new(Random) + r.encoding = base64.NewEncoding(encoderSeed) + r.ignoreMap = map[byte]struct{}{} + if len(ignore) > 0 { + bMap := map[byte]struct{}{} + for i := 0; i < len(encoderSeed); i++ { + bMap[encoderSeed[i]] = struct{}{} + } + for _, b := range ignore { + r.ignoreMap[b] = struct{}{} + delete(bMap, b) + } + r.encoder = r.encoder[:0] + for b := range bMap { + r.encoder = append(r.encoder, b) + } + } else { + r.encoder = []byte(encoderSeed) + } + + r.encoderLen = len(r.encoder) + return r +} + +// Random random string creater. +type Random struct { + encoding *base64.Encoding + encoder []byte + encoderLen int + ignoreMap map[byte]struct{} +} + +// RandomString returns a base64 encoded securely generated +// random string. It will panic if the system's secure random number generator +// fails to function correctly. +// The length n must be an integer multiple of 4, otherwise the last character will be padded with `=`. +func (r *Random) RandomString(n int) string { + d := r.encoding.DecodedLen(n) + buf := make([]byte, r.encoding.EncodedLen(d), n) + r.encoding.Encode(buf, RandomBytes(d)) + if len(r.ignoreMap) > 0 { + var ok bool + for i, b := range buf { + if _, ok = r.ignoreMap[b]; ok { + buf[i] = r.encoder[mrand.Intn(r.encoderLen)] + } + } + } + for i := n - len(buf); i > 0; i-- { + buf = append(buf, r.encoder[mrand.Intn(r.encoderLen)]) + } + + return strx.BytesToString(buf) +} + +const urlEncoder = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + +var urlRandom = &Random{ + encoding: base64.URLEncoding, + encoder: []byte(urlEncoder), + encoderLen: len(urlEncoder), + ignoreMap: map[byte]struct{}{}, +} + +// URLRandomString returns a URL-safe, base64 encoded securely generated +// random string. It will panic if the system's secure random number generator +// fails to function correctly. +// The length n must be an integer multiple of 4, otherwise the last character will be padded with `=`. +func URLRandomString(n int) string { + return urlRandom.RandomString(n) +} + +// RandomBytes returns securely generated random bytes. It will panic +// if the system's secure random number generator fails to function correctly. +func RandomBytes(n int) []byte { + b := make([]byte, n) + _, err := rand.Read(b) + // Note that err == nil only if we read len(b) bytes. + if err != nil { + panic(err) + } + return b +} + +// URandom ... +func URandom() string { + f, _ := os.Open("/dev/urandom") + b := make([]byte, 16) + f.Read(b) + f.Close() + + return fmt.Sprintf("%x", b) +} + +// GenerateRandomSeed ... +func GenerateRandomSeed() int64 { + return convertx.BytesToInt64([]byte(URandom())) +} diff --git a/reflectx/reflectx.go b/reflectx/reflectx.go new file mode 100644 index 0000000..5dd733b --- /dev/null +++ b/reflectx/reflectx.go @@ -0,0 +1,24 @@ +package gutils + +import ( + "reflect" + "unicode" + "unicode/utf8" +) + +// IsExportedName should check if it's an exported - upper case - name or not. +func IsExportedName(name string) bool { + rune, _ := utf8.DecodeRuneInString(name) + return unicode.IsUpper(rune) +} + +// IsExportedOrBuiltinType should check if this type is exported or a built-in? +func IsExportedOrBuiltinType(t reflect.Type) bool { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + // PkgPath will be non-empty even for an exported type, + // so we need to check the type name as well. + return IsExportedName(t.Name()) || t.PkgPath() == "" +} diff --git a/slicex/slicex.go b/slicex/slicex.go new file mode 100644 index 0000000..2d114d0 --- /dev/null +++ b/slicex/slicex.go @@ -0,0 +1,290 @@ +package slicex + +import ( + "reflect" + "strings" +) + +// takeArg should take arg as interface value and check if it's kind is matched +// with provided kind value. +func takeArg(arg interface{}, kind reflect.Kind) (val reflect.Value, ok bool) { + val = reflect.ValueOf(arg) + if val.Kind() == kind { + ok = true + } + return +} + +// SliceMerge merges interface slices to one slice. +func SliceMerge(slice1, slice2 []interface{}) (c []interface{}) { + c = append(slice1, slice2...) + return +} + +// SliceMergeInt merges int slices to one slice. +func SliceMergeInt(slice1, slice2 []int) (c []int) { + c = append(slice1, slice2...) + return +} + +// SliceMergeInt64 merges int64 slices to one slice. +func SliceMergeInt64(slice1, slice2 []int64) (c []int64) { + c = append(slice1, slice2...) + return +} + +// SliceMergeString merges string slices to one slice. +func SliceMergeString(slice1, slice2 []string) (c []string) { + c = append(slice1, slice2...) + return +} + +// SliceContains would return true if v is present in container. +func SliceContains(container []interface{}, v interface{}) bool { + for _, vv := range container { + if vv == v { + return true + } + } + return false +} + +// SliceContainsInt should return true if int v is present in int container. +func SliceContainsInt(container []int, v int) bool { + for _, vv := range container { + if vv == v { + return true + } + } + return false +} + +// SliceContainsUint should return true if uint v is present in uint container. +func SliceContainsUint(container []uint, v uint) bool { + for _, vv := range container { + if vv == v { + return true + } + } + return false +} + +// SliceContainsString should return true if string v is present in string +// container. +func SliceContainsString(container []string, v string) bool { + for _, vv := range container { + if vv == v { + return true + } + } + return false +} + +// SliceUniqueInt should return int container with unique values. +func SliceUniqueInt(s []int) []int { + size := len(s) + if size == 0 { + return []int{} + } + + m := make(map[int]bool) + for i := 0; i < size; i++ { + m[s[i]] = true + } + + realLen := len(m) + ret := make([]int, realLen) + + idx := 0 + for key := range m { + ret[idx] = key + idx++ + } + + return ret +} + +// SliceUniqueInt64 should return int64 container with unique values. +func SliceUniqueInt64(s []int64) []int64 { + size := len(s) + if size == 0 { + return []int64{} + } + + m := make(map[int64]bool) + for i := 0; i < size; i++ { + m[s[i]] = true + } + + realLen := len(m) + ret := make([]int64, realLen) + + idx := 0 + for key := range m { + ret[idx] = key + idx++ + } + + return ret +} + +// SliceUniqueString should return string container with unique values. +func SliceUniqueString(s []string) []string { + size := len(s) + if size == 0 { + return []string{} + } + + m := make(map[string]bool) + for i := 0; i < size; i++ { + m[s[i]] = true + } + + realLen := len(m) + ret := make([]string, realLen) + + idx := 0 + for key := range m { + ret[idx] = key + idx++ + } + + return ret +} + +// SliceSumInt should return sum of given int container values. +func SliceSumInt(container []int) (sum int) { + for _, v := range container { + sum += v + } + return +} + +// SliceSumInt64 should return sum of given int64 container values. +func SliceSumInt64(container []int64) (sum int64) { + for _, v := range container { + sum += v + } + return +} + +// SliceSumFloat64 should return sum of given float64 container values. +func SliceSumFloat64(container []float64) (sum float64) { + for _, v := range container { + sum += v + } + return +} + +// Dedup should remove duplicate uint values from slice. +func Dedup(uintSlice []uint) (dedupSlice []uint) { + for _, value := range uintSlice { + if !SliceContainsUint(dedupSlice, value) { + dedupSlice = append(dedupSlice, value) + } + } + return +} + +// RemoveDuplicates should remove the duplicate values from slice of string. +// If caseSensitive is true, then, it'll distinguish string "Abc" from "abc". +// +// Input: RemoveDuplicates(&[]string{"abc", "Abc"}, false) -> Output: &[abc] +// Input: RemoveDuplicates(&[]string{"abc", "Abc"}, true) -> Output: &[abc Abc] +func RemoveDuplicates(list *[]string, caseSensitive bool) { + found := make(map[string]bool) + j := 0 + for i, x := range *list { + if !caseSensitive { + x = strings.ToLower(x) + } + if !found[x] { + found[x] = true + (*list)[j] = (*list)[i] + j++ + } + } + *list = (*list)[:j] +} + +// Distinct returns the unique vals of a slice. +func Distinct(arr interface{}) (reflect.Value, bool) { + // Create a slice from our input interface. + slice, ok := takeArg(arr, reflect.Slice) + if !ok { + return reflect.Value{}, ok + } + + // Put the values of our slice into a map the key's of the map will be the + // slice's unique values. + c := slice.Len() + m := make(map[interface{}]bool) + for i := 0; i < c; i++ { + m[slice.Index(i).Interface()] = true + } + + i := 0 + mapLen := len(m) + + // Create the output slice and populate it with the map's keys + out := reflect.MakeSlice(reflect.TypeOf(arr), mapLen, mapLen) + for k := range m { + v := reflect.ValueOf(k) + o := out.Index(i) + o.Set(v) + i++ + } + + return out, ok +} + +// Intersect returns a slice of values that are present in all of the input slices. +func Intersect(arrs ...interface{}) (reflect.Value, bool) { + // Create a map to count all the instances of the slice elems. + arrLength := len(arrs) + var kind reflect.Kind + + tempMap := make(map[interface{}]int) + for i, arg := range arrs { + tempArr, ok := Distinct(arg) + if !ok { + return reflect.Value{}, ok + } + + // check to be sure the type hasn't changed + if i > 0 && tempArr.Index(0).Kind() != kind { + return reflect.Value{}, false + } + kind = tempArr.Index(0).Kind() + + c := tempArr.Len() + for idx := 0; idx < c; idx++ { + if _, ok := tempMap[tempArr.Index(idx).Interface()]; ok { + tempMap[tempArr.Index(idx).Interface()]++ + } else { + tempMap[tempArr.Index(idx).Interface()] = 1 + } + } + } + + // Find the keys equal to the length of the input args. + numElems := 0 + for _, v := range tempMap { + if v == arrLength { + numElems++ + } + } + + i := 0 + out := reflect.MakeSlice(reflect.TypeOf(arrs[0]), numElems, numElems) + for key, val := range tempMap { + if val == arrLength { + v := reflect.ValueOf(key) + o := out.Index(i) + o.Set(v) + i++ + } + } + + return out, true +} diff --git a/sortx/sortx.go b/sortx/sortx.go new file mode 100644 index 0000000..071cf5c --- /dev/null +++ b/sortx/sortx.go @@ -0,0 +1,128 @@ +package gutils + +// QSortT should sort slices type []T from quickly small to large. +func QSortT(arr interface{}, start2End ...int) { + var start, end, low, high int + + switch _arr := arr.(type) { + case []int: + switch len(start2End) { + case 0: + start = 0 + end = len(_arr) - 1 + case 1: + start = start2End[0] + end = len(_arr) - 1 + default: + start = start2End[0] + end = start2End[1] + } + low = start + high = end + key := _arr[start] + + for { + for low < high { + if _arr[high] < key { + _arr[low] = _arr[high] + break + } + high-- + } + for low < high { + if _arr[low] > key { + _arr[high] = _arr[low] + break + } + low++ + } + if low >= high { + _arr[low] = key + break + } + } + + case []uint64: + switch len(start2End) { + case 0: + start = 0 + end = len(_arr) - 1 + case 1: + start = start2End[0] + end = len(_arr) - 1 + default: + start = start2End[0] + end = start2End[1] + } + low = start + high = end + key := _arr[start] + + for { + for low < high { + if _arr[high] < key { + _arr[low] = _arr[high] + break + } + high-- + } + for low < high { + if _arr[low] > key { + _arr[high] = _arr[low] + break + } + low++ + } + if low >= high { + _arr[low] = key + break + } + } + + case []string: + switch len(start2End) { + case 0: + start = 0 + end = len(_arr) - 1 + case 1: + start = start2End[0] + end = len(_arr) - 1 + default: + start = start2End[0] + end = start2End[1] + } + low = start + high = end + key := _arr[start] + + for { + for low < high { + if _arr[high] < key { + _arr[low] = _arr[high] + break + } + high-- + } + for low < high { + if _arr[low] > key { + _arr[high] = _arr[low] + break + } + low++ + } + if low >= high { + _arr[low] = key + break + } + } + default: + return + } + + if low-1 > start { + QSortT(arr, start, low-1) + } + if high+1 < end { + QSortT(arr, high+1, end) + } +} diff --git a/strx/constants.go b/strx/constants.go new file mode 100644 index 0000000..1389b2e --- /dev/null +++ b/strx/constants.go @@ -0,0 +1,115 @@ +package strx + +import "regexp" + +const ( + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1<|\r\n]+\\)*[^\\/:*?"<>|\r\n]*$` + // UnixPath string = `^(/[^/\x00]*)+/?$` + // Semver string = "^v?(?:0|[1-9]\\d*)\\.(?:0|[1-9]\\d*)\\.(?:0|[1-9]\\d*)(-(0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(\\.(0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*)?(\\+[0-9a-zA-Z-]+(\\.[0-9a-zA-Z-]+)*)?$" + // tagName string = "valid" + + // ----------------------------------------------------------------------------- +) + +var ( + rxEmail = regexp.MustCompile(Email) + // rxCreditCard = regexp.MustCompile(CreditCard) + // rxISBN10 = regexp.MustCompile(ISBN10) + // rxISBN13 = regexp.MustCompile(ISBN13) + // rxUUID3 = regexp.MustCompile(UUID3) + // rxUUID4 = regexp.MustCompile(UUID4) + // rxUUID5 = regexp.MustCompile(UUID5) + // rxUUID = regexp.MustCompile(UUID) + // rxAlpha = regexp.MustCompile(Alpha) + // rxAlphanumeric = regexp.MustCompile(Alphanumeric) + // rxNumeric = regexp.MustCompile(Numeric) + // rxInt = regexp.MustCompile(Int) + // rxFloat = regexp.MustCompile(Float) + // rxHexadecimal = regexp.MustCompile(Hexadecimal) + // rxHexcolor = regexp.MustCompile(Hexcolor) + // rxRGBcolor = regexp.MustCompile(RGBcolor) + // rxASCII = regexp.MustCompile(ASCII) + // rxPrintableASCII = regexp.MustCompile(PrintableASCII) + // rxMultibyte = regexp.MustCompile(Multibyte) + // rxFullWidth = regexp.MustCompile(FullWidth) + // rxHalfWidth = regexp.MustCompile(HalfWidth) + // rxBase64 = regexp.MustCompile(Base64) + // rxDataURI = regexp.MustCompile(DataURI) + // rxLatitude = regexp.MustCompile(Latitude) + // rxLongitude = regexp.MustCompile(Longitude) + // rxDNSName = regexp.MustCompile(DNSName) + + // rxURL is the regular expression of URL. + rxURL = regexp.MustCompile(URL) + + // rxSSN = regexp.MustCompile(SSN) + // rxWinPath = regexp.MustCompile(WinPath) + // rxUnixPath = regexp.MustCompile(UnixPath) + // rxSemver = regexp.MustCompile(Semver) + + // RXURL ... + RXURL = rxURL +) diff --git a/strx/strx.go b/strx/strx.go new file mode 100644 index 0000000..a03c726 --- /dev/null +++ b/strx/strx.go @@ -0,0 +1,300 @@ +package strx + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "fmt" + "io" + "math/rand" + "net/url" + "reflect" + "runtime" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" + "unsafe" + + "github.com/aljiwala/gutils/slicex" +) + +// CamelCaseSplit splits the camelcase word and returns a list of words. It also +// supports digits. Both lower camel case and upper camel case are supported. +// +// Examples: +// "" => [""] +// "lowercase" => ["lowercase"] +// "Class" => ["Class"] +// "MyClass" => ["My", "Class"] +// "MyC" => ["My", "C"] +// "HTML" => ["HTML"] +// "PDFLoader" => ["PDF", "Loader"] +// "AString" => ["A", "String"] +// "SimpleXMLParser" => ["Simple", "XML", "Parser"] +// "vimRPCPlugin" => ["vim", "RPC", "Plugin"] +// "GL11Version" => ["GL", "11", "Version"] +// "99Bottles" => ["99", "Bottles"] +// "May5" => ["May", "5"] +// "BFG9000" => ["BFG", "9000"] +// "BöseÜberraschung" => ["Böse", "Überraschung"] +// "Two spaces" => ["Two", " ", "spaces"] +// "BadUTF8\xe2\xe2\xa1" => ["BadUTF8\xe2\xe2\xa1"] +// +// Splitting rules +// +// 1) If string is not valid UTF-8, return it without splitting as +// single item array. +// 2) Assign all unicode characters into one of 4 sets: lower case +// letters, upper case letters, numbers, and all other characters. +// 3) Iterate through characters of string, introducing splits +// between adjacent characters that belong to different sets. +// 4) Iterate through array of split strings, and if a given string +// is upper case: +// if subsequent string is lower case: +// move last character of upper case string to beginning of +// lower case string +func CamelCaseSplit(str string) (splitted []string) { + return camelCaseSplit(str) +} + +//////////////////////////////////////////////////////////////////////////////// + +func camelCaseSplit(str string) (container []string) { + // don't split invalid utf8 + if !utf8.ValidString(str) { + return []string{str} + } + + lastClass := 0 + container = []string{} + var ( + class int + runes [][]rune + ) + + // split into fields based on class of unicode character + for _, r := range str { + switch true { + case unicode.IsLower(r): + class = 1 + case unicode.IsUpper(r): + class = 2 + case unicode.IsDigit(r): + class = 3 + default: + class = 4 + } + if class == lastClass { + runes[len(runes)-1] = append(runes[len(runes)-1], r) + } else { + runes = append(runes, []rune{r}) + } + lastClass = class + } + + // handle upper case -> lower case sequences, e.g. + // "PDFL", "oader" -> "PDF", "Loader" + for i := 0; i < len(runes)-1; i++ { + if unicode.IsUpper(runes[i][0]) && unicode.IsLower(runes[i+1][0]) { + runes[i+1] = append([]rune{runes[i][len(runes[i])-1]}, runes[i+1]...) + runes[i] = runes[i][:len(runes[i])-1] + } + } + + // construct []string from results + for _, s := range runes { + if len(s) > 0 { + container = append(container, string(s)) + } + } + + return +} + +// IsNull should check if the string is null. +func IsNull(s string) bool { + return strings.TrimSpace(s) == "" +} + +// IsEmail should check if the string is an email. +func IsEmail(str string) bool { + // TODO uppercase letters are not supported + return rxEmail.MatchString(str) +} + +// TrimAndLowercase should trim space and lower the string. +func TrimAndLowercase(str string) string { + return strings.ToLower(strings.Replace(str, " ", "", -1)) +} + +// TrimRightSpace should return trimmed string after removing newline, tabline, +// et cetera. +func TrimRightSpace(s string) string { + return strings.TrimRight(string(s), "\r\n\t ") +} + +// BytesToString convert []byte type to string type. +func BytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// StringToBytes convert string type to []byte type. +// NOTE: Panics; if modify the member value of the []byte. +func StringToBytes(s string) []byte { + sp := *(*[2]uintptr)(unsafe.Pointer(&s)) + bp := [3]uintptr{sp[0], sp[1], sp[1]} + return *(*[]byte)(unsafe.Pointer(&bp)) +} + +// Md5 should return hash string with MD5 checksum. +func Md5(s string) string { + h := md5.New() + h.Write([]byte(s)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// HashStr should return hexadecimal encoding string of new `sha1` checksum. +func HashStr(s string) string { + h := sha1.New() + if _, err := h.Write([]byte(s)); err != nil { + return "" + } + return hex.EncodeToString(h.Sum(nil)) +} + +// HashPassword should return hash string of password by writing salt to it. +func HashPassword(pwd, salt string) string { + h := sha1.New() + io.WriteString(h, salt) + io.WriteString(h, pwd) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +// Left returns the "n" left characters of the string. +// +// If the string is shorter than "n" it will return the first "n" characters of +// the string with "…" appended. Otherwise the entire string is returned as-is. +func Left(s string, n int) string { + if n < 0 { + n = 0 + } + if len(s) <= n { + return s + } + return s[:n] + "…" +} + +// SplitStr should return container of strings splitted by comma value (,). +func SplitStr(str string) (result []string) { + s := strings.Split(str, ",") + for i := range s { + s[i] = strings.TrimSpace(s[i]) + if s[i] != "" { + result = append(result, s[i]) + } + } + return +} + +// SplitAndRemoveDups should split and return unique set of string values. +func SplitAndRemoveDups(str string) (splitted []string) { + splitted = SplitStr(str) + slicex.RemoveDuplicates(&splitted, false) + return +} + +// GenerateRandStr generates the 64-bit long random unique string. Additionaly, +// takes n and strGen to create custom string for the same. +func GenerateRandStr(n int, strGen string) string { + b := make([]byte, n) + randNewSource := rand.NewSource(time.Now().UnixNano()) + + // A randNewSource.Int63() generates 63 random bits, enough for + // letterIdxMax characters! + for i, cache, remain := n-1, randNewSource.Int63(), letterIdxMax; i >= 0; { + if remain == 0 { + cache, remain = randNewSource.Int63(), letterIdxMax + } + if idx := int(cache & letterIdxMask); idx < len(strGen) { + b[i] = strGen[idx] + i-- + } + cache >>= letterIdxBits + remain-- + } + + return string(b) + strconv.FormatInt(time.Now().Unix(), 10) +} + +// SnakeStr converts the accepted string to a snake string (XxYy to xx_yy). +func SnakeStr(s string) string { + data := make([]byte, 0, len(s)*2) + j := false + num := len(s) + + for i := 0; i < num; i++ { + d := s[i] + if i > 0 && d >= 'A' && d <= 'Z' && j { + data = append(data, '_') + } + if d != '_' { + j = true + } + data = append(data, d) + } + + return strings.ToLower(string(data[:])) +} + +// CamelStr converts the accepted string to a camel string (xx_yy to XxYy). +func CamelStr(s string) string { + data := make([]byte, 0, len(s)) + j := false + k := false + num := len(s) - 1 + + for i := 0; i <= num; i++ { + d := s[i] + if k == false && d >= 'A' && d <= 'Z' { + k = true + } + if d >= 'a' && d <= 'z' && (j || k == false) { + d = d - 32 + j = false + k = true + } + if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { + j = true + continue + } + data = append(data, d) + } + + return string(data[:]) +} + +// ObjectName gets the type name of the object. +func ObjectName(obj interface{}) string { + v := reflect.ValueOf(obj) + t := v.Type() + if t.Kind() == reflect.Func { + return runtime.FuncForPC(v.Pointer()).Name() + } + + return t.String() +} + +// JsQueryEscape escapes the string in javascript standard so it can be safely placed +// inside a URL query. +func JsQueryEscape(s string) string { + return strings.Replace(url.QueryEscape(s), "+", "%20", -1) +} + +// JsQueryUnescape does the inverse transformation of JsQueryEscape, converting +// %AB into the byte 0xAB and '+' into ' ' (space). It returns an error if +// any % is not followed by two hexadecimal digits. +func JsQueryUnescape(s string) (string, error) { + return url.QueryUnescape(strings.Replace(s, "%20", "+", -1)) +} diff --git a/strx/strx_test.go b/strx/strx_test.go new file mode 100644 index 0000000..94ba07a --- /dev/null +++ b/strx/strx_test.go @@ -0,0 +1,30 @@ +package strx + +import ( + "fmt" + "testing" +) + +func TestLeft(t *testing.T) { + cases := []struct { + in string + n int + want string + }{ + {"Hello", 100, "Hello"}, + {"Hello", 1, "H…"}, + {"Hello", 5, "Hello"}, + {"Hello", 4, "Hell…"}, + {"Hello", 0, "…"}, + {"Hello", -2, "…"}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + out := Left(tc.in, tc.n) + if out != tc.want { + t.Errorf("\nout: %#v\nwant: %#v\n", out, tc.want) + } + }) + } +} diff --git a/syncx/syncx.go b/syncx/syncx.go new file mode 100644 index 0000000..4e4750b --- /dev/null +++ b/syncx/syncx.go @@ -0,0 +1,22 @@ +package syncx + +import ( + "context" + "sync" +) + +// Wait for a sync.WaitGroup with support for timeout/cancellations from context. +func Wait(ctx context.Context, wg *sync.WaitGroup) error { + ch := make(chan struct{}) + go func() { + defer close(ch) + wg.Wait() + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-ch: + return nil + } +} diff --git a/syncx/syncx_test.go b/syncx/syncx_test.go new file mode 100644 index 0000000..272d634 --- /dev/null +++ b/syncx/syncx_test.go @@ -0,0 +1,45 @@ +package syncx + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestWait(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + t.Run("cancel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := Wait(ctx, &wg) + if err != context.Canceled { + t.Errorf("wrong error: %v", err) + } + }) + + t.Run("timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + err := Wait(ctx, &wg) + if err != context.DeadlineExceeded { + t.Errorf("wrong error: %v", err) + } + }) + + t.Run("finish", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + wg.Done() + wg.Done() + + err := Wait(ctx, &wg) + if err != nil { + t.Errorf("wrong error: %v", err) + } + }) +} diff --git a/timex/constants.go b/timex/constants.go new file mode 100644 index 0000000..eb34e81 --- /dev/null +++ b/timex/constants.go @@ -0,0 +1,34 @@ +package timex + +// Represents the number of elements in a given period. +const ( + maxNSecs = 999999999 + secondsPerMinute = 60 + minutesPerHour = 60 + secondsInHour = secondsPerMinute * minutesPerHour + hoursPerDay = 24 + daysPerWeek = 7 + monthsPerQuarter = 3 + monthsPerYear = 12 + yearsPerCenturies = 100 + yearsPerDecade = 10 + weeksPerLongYear = 53 + daysInLeapYear = 366 + daysInNormalYear = 365 + secondsInWeek = secondsPerMinute * minutesPerHour * hoursPerDay * daysPerWeek + secondsInMonth = 2678400 + + SecondsPerMinute = secondsPerMinute + MinutesPerHour = minutesPerHour + HoursPerDay = hoursPerDay + DaysPerWeek = daysPerWeek + MonthsPerQuarter = monthsPerQuarter + MonthsPerYear = monthsPerYear + YearsPerCenturies = yearsPerCenturies + YearsPerDecade = yearsPerDecade + WeeksPerLongYear = weeksPerLongYear + DaysInLeapYear = daysInLeapYear + DaysInNormalYear = daysInNormalYear + SecondsInWeek = secondsInWeek + SecondsInMonth = secondsInMonth +) diff --git a/timex/monthx/constants.go b/timex/monthx/constants.go new file mode 100644 index 0000000..9ac4a97 --- /dev/null +++ b/timex/monthx/constants.go @@ -0,0 +1,57 @@ +// Package monthx is the time related utils. +package monthx + +// Month represents month type as month value in int. +type Month int + +// Standard month values (as an int). +const ( + January Month = iota + 1 + February + March + April + May + June + July + August + September + October + November + December +) + +var ( + // MonthDays is the map of month with total days of it. + MonthDays = map[int]int{ + 1: 31, + 3: 31, + 4: 30, + 5: 31, + 6: 30, + 7: 31, + 8: 31, + 9: 30, + 10: 31, + 11: 30, + 12: 31, + } + + // DaysBefore counts the number of days in a non-leap year + // before month m begins. There is an entry for m=12, counting + // the number of days before January of next year (365). + DaysBefore = [...]int32{ + 0, + 31, + 31 + 28, + 31 + 28 + 31, + 31 + 28 + 31 + 30, + 31 + 28 + 31 + 30 + 31, + 31 + 28 + 31 + 30 + 31 + 30, + 31 + 28 + 31 + 30 + 31 + 30 + 31, + 31 + 28 + 31 + 30 + 31 + 30 + 31 + 31, + 31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30, + 31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31, + 31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31 + 30, + 31 + 28 + 31 + 30 + 31 + 30 + 31 + 31 + 30 + 31 + 30 + 31, + } +) diff --git a/timex/monthx/monthx.go b/timex/monthx/monthx.go new file mode 100644 index 0000000..1de235c --- /dev/null +++ b/timex/monthx/monthx.go @@ -0,0 +1,242 @@ +// Package monthx provides functionality, pertaining to months of the year, that +// is not available in standard Go libraries. +package monthx + +import ( + "time" + + "github.com/aljiwala/gutils/timex" +) + +// daysIn should return numeric value as in total days in given month and year. +func daysIn(year int, m time.Month) int { + return time.Date(year, m+1, 0, 0, 0, 0, 0, time.UTC).Day() +} + +// ----------------------------------------------------------------------------- + +// TimeMonth should returns time.Month representation of the given month. +func (m Month) TimeMonth() time.Month { + return time.Month(m) +} + +// Int should return int representation of give month. +func (m Month) Int() int { + return int(m.TimeMonth()) +} + +// String returns an English language based representation. +// e.g. January, February, etc. +func (m Month) String() string { + return m.TimeMonth().String() +} + +// ----------------------------------------------------------------------------- + +// IsJanuary should return true if the given month is January. +func (m Month) IsJanuary() bool { + return IsJanuary(m.TimeMonth()) +} + +// IsFebruary should return true if the given month is February. +func (m Month) IsFebruary() bool { + return IsFebruary(m.TimeMonth()) +} + +// IsMarch should return true if the given month is March. +func (m Month) IsMarch() bool { + return IsMarch(m.TimeMonth()) +} + +// IsApril should return true if the given month is April. +func (m Month) IsApril() bool { + return IsApril(m.TimeMonth()) +} + +// IsMay should return true if the given month is May. +func (m Month) IsMay() bool { + return IsMay(m.TimeMonth()) +} + +// IsJune should return true if the given month is June. +func (m Month) IsJune() bool { + return IsJune(m.TimeMonth()) +} + +// IsJuly should return true if the given month is July. +func (m Month) IsJuly() bool { + return IsJuly(m.TimeMonth()) +} + +// IsAugust should return true if the given month is August. +func (m Month) IsAugust() bool { + return IsAugust(m.TimeMonth()) +} + +// IsSeptember should return true if the given month is September. +func (m Month) IsSeptember() bool { + return IsSeptember(m.TimeMonth()) +} + +// IsOctober should return true if the given month is October. +func (m Month) IsOctober() bool { + return IsOctober(m.TimeMonth()) +} + +// IsNovember should return true if the given month is November. +func (m Month) IsNovember() bool { + return IsNovember(m.TimeMonth()) +} + +// IsDecember should return true if the given month is December. +func (m Month) IsDecember() bool { + return IsDecember(m.TimeMonth()) +} + +// ----------------------------------------------------------------------------- + +// IsJanuary should return true if given month is January. +func IsJanuary(m time.Month) bool { + return m.String() == January.String() +} + +// IsFebruary should return true if given month is February. +func IsFebruary(m time.Month) bool { + return m.String() == February.String() +} + +// IsMarch should return true if given month is March. +func IsMarch(m time.Month) bool { + return m.String() == March.String() +} + +// IsApril should return true if given month is April. +func IsApril(m time.Month) bool { + return m.String() == April.String() +} + +// IsMay should return true if given month is May. +func IsMay(m time.Month) bool { + return m.String() == May.String() +} + +// IsJune should return true if given month is June. +func IsJune(m time.Month) bool { + return m.String() == June.String() +} + +// IsJuly should return true if given month is July. +func IsJuly(m time.Month) bool { + return m.String() == July.String() +} + +// IsAugust should return true if given month is August. +func IsAugust(m time.Month) bool { + return m.String() == August.String() +} + +// IsSeptember should return true if given month is September. +func IsSeptember(m time.Month) bool { + return m.String() == September.String() +} + +// IsOctober should return true if given month is October. +func IsOctober(m time.Month) bool { + return m.String() == October.String() +} + +// IsNovember should return true if given month is November. +func IsNovember(m time.Month) bool { + return m.String() == November.String() +} + +// IsDecember should return true if given month is December. +func IsDecember(m time.Month) bool { + return m.String() == December.String() +} + +// ----------------------------------------------------------------------------- + +// DaysIn should return total number of days of particular month and year. +// +// Example: +// - timex.DaysIn(2020, monthx.June.TimeMonth()) returns 30. +func (m Month) DaysIn(year int) int { + return daysIn(year, m.TimeMonth()) +} + +// LastDay returns the last numeric day of the given month and year. Also takes +// leap years into account. +// +// Example: +// - monthx.June.LastDay(1992) should return 30. +func (m Month) LastDay(year int) int { + return timex.LastDayOfMonth(year, time.Month(m)) +} + +// DaysIn should return total number of days of particular month and year. +// +// Example: +// - timex.DaysIn(2020, monthx.June.TimeMonth()) returns 30. +func DaysIn(year int, month time.Month) int { + return daysIn(year, month) +} + +// StartOfMonth returns the first day of the month of date. +func StartOfMonth(date time.Time) time.Time { + return time.Date(date.Year(), date.Month(), 1, 0, 0, 0, 0, date.Location()) +} + +// EndOfMonth returns the last day of the month of date. +func EndOfMonth(date time.Time) time.Time { + // Go to the next month, then a day of 0 removes a day leaving us at the + // last day of dates month. + return time.Date(date.Year(), date.Month()+1, 0, 0, 0, 0, 0, date.Location()) +} + +// GetStartAndEndOfMonth should return start and end time.Time value by getting +// the same type. +func GetStartAndEndOfMonth(t time.Time) (start, end time.Time) { + year, month, _ := t.Date() + start = time.Date(year, month, 1, 0, 0, 0, 0, t.Location()) + end = start.AddDate(0, 1, -1) + return +} + +// MonthsTo returns the number of months from the current date to the given date. +// The number of months is always rounded down, with a minimal value of 1. +// +// Example: +// - MonthsTo(time.Now().Add(24 * time.Hour * 70)) should return 2. +// +// Dates in the past are not supported, and their behaviour is undefined! +func MonthsTo(a time.Time) int { + var days int + startDate := time.Now() + lastDayOfYear := func(t time.Time) time.Time { + return time.Date(t.Year(), 12, 31, 0, 0, 0, 0, t.Location()) + } + + firstDayOfNextYear := func(t time.Time) time.Time { + return time.Date(t.Year()+1, 1, 1, 0, 0, 0, 0, t.Location()) + } + + cur := startDate + for cur.Year() < a.Year() { + // add 1 to count the last day of the year too. + days += lastDayOfYear(cur).YearDay() - cur.YearDay() + 1 + cur = firstDayOfNextYear(cur) + } + + days += a.YearDay() - cur.YearDay() + if startDate.AddDate(0, 0, days).After(a) { + days-- + } + + months := (days / 30) + if months == 0 { + months = 1 + } + + return months +} diff --git a/timex/monthx/monthx_test.go b/timex/monthx/monthx_test.go new file mode 100644 index 0000000..d50e46b --- /dev/null +++ b/timex/monthx/monthx_test.go @@ -0,0 +1,118 @@ +package monthx + +import ( + "fmt" + "testing" + "time" +) + +// mustParse parses value in the format YYYY-MM-DD failing the test on error. +func mustParse(t *testing.T, value string) time.Time { + const layout = "2006-01-02" + d, err := time.Parse(layout, value) + if err != nil { + t.Fatalf("time.Parse(%q, %q) unexpected error: %v", layout, value, err) + } + return d +} + +func TestStartOfMonth(t *testing.T) { + cases := []struct { + in time.Time + want time.Time + }{ + {mustParse(t, "2016-01-13"), mustParse(t, "2016-01-01")}, + {mustParse(t, "2016-01-01"), mustParse(t, "2016-01-01")}, + {mustParse(t, "2016-12-30"), mustParse(t, "2016-12-01")}, + } + + for _, c := range cases { + got := StartOfMonth(c.in) + if got != c.want { + t.Errorf("StartOfMonth(%s) => %s, want %s", c.in, got, c.want) + } + } +} + +func TestEndOfMonth(t *testing.T) { + cases := []struct { + in time.Time + want time.Time + }{ + {mustParse(t, "2016-01-01"), mustParse(t, "2016-01-31")}, + {mustParse(t, "2016-01-31"), mustParse(t, "2016-01-31")}, + {mustParse(t, "2016-11-01"), mustParse(t, "2016-11-30")}, + {mustParse(t, "2016-12-31"), mustParse(t, "2016-12-31")}, + // Leap test. + {mustParse(t, "2012-02-01"), mustParse(t, "2012-02-29")}, + {mustParse(t, "2013-02-01"), mustParse(t, "2013-02-28")}, + } + + for _, c := range cases { + got := EndOfMonth(c.in) + if got != c.want { + t.Errorf("EndOfMonth(%s) => %s, want %s", c.in, got, c.want) + } + } +} + +func TestMonthsTo(t *testing.T) { + day := 24 * time.Hour + cases := []struct { + in time.Time + want int + }{ + {time.Now(), 1}, + {time.Now().Add(day * 35), 1}, + {time.Now().Add(day * 65), 2}, + {time.Now().Add(day * 370), 12}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + out := MonthsTo(tc.in) + if out != tc.want { + t.Errorf("\nout: %#v\nwant: %#v\n", out, tc.want) + } + }) + } +} + +func TestString(t *testing.T) { + m := Month(January) + monthName := m.String() + if monthName != "January" { + t.Fatal("For", "1", + "expected", "January", + "got", monthName, + ) + } +} + +func TestLastDay(t *testing.T) { + lastDays := [12]int{ + 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31, + } + + // Normal years + for i := 1; i <= 12; i++ { + m := Month(i) + actual := m.LastDay(2015) + expected := lastDays[i-1] + if actual != expected { + t.Error( + "For", m, + "expected", expected, + "got", actual) + } + } + + // Leap year + actual := Month(February).LastDay(2012) + if actual != 29 { + t.Error( + "For leap year", + "expected", 29, + "got", actual) + } +} diff --git a/timex/timex.go b/timex/timex.go new file mode 100644 index 0000000..a14d9fd --- /dev/null +++ b/timex/timex.go @@ -0,0 +1,853 @@ +package timex + +import ( + "math" + "strings" + "time" +) + +// TimeSpan holds the block of the time period. +type TimeSpan struct { + Begin time.Time + End time.Time +} + +// absValue returns the abs value if needed. +func absValue(needsAbs bool, value int64) int64 { + if needsAbs && value < 0 { + return -value + } + return value +} + +// now returns the current local time. +func now() time.Time { + return time.Now() +} + +// nowIn should return time.Time in provided location. +func nowIn(loc *time.Location) time.Time { + return now().In(loc) +} + +// create returns a new carbon pointe. It is a helper function to create new +// dates. +func create(y int, mon time.Month, d, h, m, s, ns int, l *time.Location) time.Time { + return time.Date(y, mon, d, h, m, s, ns, l) +} + +// isPast determines if the current time is in the past, i.e. less (before) +// than Now(). +func isPast(t time.Time) bool { + return t.Before(Now()) +} + +// isToday should check if the given date is matched with todays date (alias of +// `isSameDay`). +func isToday(t time.Time) bool { + return isSameDay(t) +} + +// isFuture determines if the current time is in the future, i.e. greater (after) +// than Now(). +func isFuture(t time.Time) bool { + return t.After(Now()) +} + +// isCurrentDay determines if the current time is in the current day. +func isCurrentDay(t time.Time) bool { + return t.Day() == Now().Day() +} + +// isCurrentMonth determines if the current time is in the current month. +func isCurrentMonth(t time.Time) bool { + return t.Month() == Now().Month() +} + +// isCurrentYear determines if the current time is in the current year. +func isCurrentYear(t time.Time) bool { + return t.Year() == Now().Year() +} + +// isSameDay checks if the given date is the same day as the current day. +func isSameDay(t time.Time) bool { + n := nowIn(t.Location()) + return t.Year() == n.Year() && t.Month() == n.Month() && t.Day() == n.Day() +} + +// isSameMonth checks if month of the given date is same as month of the current +// date. +func isSameMonth(t time.Time, sameYear bool) bool { + m := nowIn(t.Location()).Month() + if sameYear { + return isSameYear(t) && m == t.Month() + } + return m == t.Month() +} + +// isSameYear checks if given date is in current year. +func isSameYear(t time.Time) bool { + return t.Year() == nowIn(t.Location()).Year() +} + +// addMilliSecond adds a millisecond to the time. +// Positive values travels forward while negative values travels into the past. +func addMilliSecond(t time.Time) time.Time { + return addMilliSeconds(t, 1) +} + +// addMilliSeconds adds milliseconds to the current time. +func addMilliSeconds(t time.Time, s time.Duration) time.Time { + d := time.Duration(s) * time.Millisecond + return t.Add(d) +} + +// addSecond adds a second to the time. +// Positive values travels forward while negative values travels into the past. +func addSecond(t time.Time) time.Time { + return addSeconds(t, 1) +} + +// addSeconds adds seconds to the current time. +func addSeconds(t time.Time, s time.Duration) time.Time { + d := time.Duration(s) * time.Second + return t.Add(d) +} + +// addMinute adds a minute to the time. +// Positive values travels forward while negative values travels into the past. +func addMinute(t time.Time) time.Time { + return addMinutes(t, 1) +} + +// addMinutes adds minutes to the current time. +func addMinutes(t time.Time, m time.Duration) time.Time { + d := time.Duration(m) * time.Minute + return t.Add(d) +} + +// addHour adds an hour to the time. +// Positive values travels forward while negative values travels into the past. +func addHour(t time.Time) time.Time { + return addHours(t, 1) +} + +// addHours adds hours to the current time. +func addHours(t time.Time, h int) time.Time { + d := time.Duration(h) * time.Hour + return t.Add(d) +} + +// addDay adds a day to the time. +// Positive values travels forward while negative values travels into the past. +func addDay(t time.Time) time.Time { + return addDays(t, 1) +} + +// addDays adds days to the current time. +func addDays(t time.Time, d int) time.Time { + return t.AddDate(0, 0, d) +} + +// addWeek adds a week to the time. +// Positive values travels forward while negative values travels into the past. +func addWeek(t time.Time) time.Time { + return addWeeks(t, 1) +} + +// addWeeks adds weeks to the current time. +func addWeeks(t time.Time, w int) time.Time { + return t.AddDate(0, 0, daysPerWeek*w) +} + +// addMonth adds a month to the time. +// Positive values travels forward while negative values travels into the past. +func addMonth(t time.Time) time.Time { + return addMonths(t, 1) +} + +// addMonths adds months to the current time. +func addMonths(t time.Time, m int) time.Time { + return t.AddDate(0, m, 0) +} + +// addQuarter adds a quarter to the time. +// Positive values travels forward while negative values travels into the past. +func addQuarter(t time.Time) time.Time { + return addQuarters(t, 1) +} + +// addQuarters adds quarters to the current time. +func addQuarters(t time.Time, q int) time.Time { + return t.AddDate(0, monthsPerQuarter*q, 0) +} + +// addYear adds a year to the time. +// Positive values travels forward while negative values travels into the past. +func addYear(t time.Time) time.Time { + return addYears(t, 1) +} + +// addYears adds years to the current time. +func addYears(t time.Time, y int) time.Time { + return t.AddDate(y, 0, 0) +} + +// addCentury adds a century to the time. +// Positive values travels forward while negative values travels into the past. +func addCentury(t time.Time) time.Time { + return addCenturies(t, 1) +} + +// addCenturies adds centuries to the current time. +func addCenturies(t time.Time, c int) time.Time { + return t.AddDate(yearsPerCenturies*c, 0, 0) +} + +// ----------------------------------------------------------------------------- + +// subMilliSecond removes a millisecond to the time. +// Positive values travels forward while negative values travels into the past. +func subMilliSecond(t time.Time) time.Time { + return subMilliSeconds(t, 1) +} + +// subMilliSeconds removes milliseconds to the current time. +func subMilliSeconds(t time.Time, s time.Duration) time.Time { + return addMilliSeconds(t, -s) +} + +// subSecond removes a second to the time. +// Positive values travels forward while negative values travels into the past. +func subSecond(t time.Time) time.Time { + return subSeconds(t, -1) +} + +// subSeconds removes seconds to the current time. +func subSeconds(t time.Time, s time.Duration) time.Time { + return addSeconds(t, -s) +} + +// subMinute removes a minute to the time. +// Positive values travels forward while negative values travels into the past. +func subMinute(t time.Time) time.Time { + return subMinutes(t, 1) +} + +// subMinutes removes minutes to the current time. +func subMinutes(t time.Time, m time.Duration) time.Time { + return addMinutes(t, -m) +} + +// subHour removes an hour to the time. +// Positive values travels forward while negative values travels into the past. +func subHour(t time.Time) time.Time { + return subHours(t, 1) +} + +// subHours removes hours to the current time. +func subHours(t time.Time, h int) time.Time { + return addHours(t, -h) +} + +// subDay removes a day to the time. +// Positive values travels forward while negative values travels into the past. +func subDay(t time.Time) time.Time { + return subDays(t, 1) +} + +// subDays removes days to the current time. +func subDays(t time.Time, d int) time.Time { + return addDays(t, -d) +} + +// subWeek removes a week to the time. +// Positive values travels forward while negative values travels into the past. +func subWeek(t time.Time) time.Time { + return subWeeks(t, 1) +} + +// subWeeks removes weeks to the current time. +func subWeeks(t time.Time, w int) time.Time { + return addWeeks(t, -w) +} + +// subMonth removes a month to the time. +// Positive values travels forward while negative values travels into the past. +func subMonth(t time.Time) time.Time { + return subMonths(t, 1) +} + +// addMonths removes months to the current time. +func subMonths(t time.Time, m int) time.Time { + return addMonths(t, -m) +} + +// subQuarter removes a quarter to the time. +// Positive values travels forward while negative values travels into the past. +func subQuarter(t time.Time) time.Time { + return subQuarters(t, 1) +} + +// subQuarters removes quarters to the current time. +func subQuarters(t time.Time, q int) time.Time { + return addQuarters(t, -q) +} + +// subYear removes a year to the time. +// Positive values travels forward while negative values travels into the past. +func subYear(t time.Time) time.Time { + return subYears(t, 1) +} + +// subYears removes years to the current time. +func subYears(t time.Time, y int) time.Time { + return addYears(t, -y) +} + +// subCentury removes a century to the time. +// Positive values travels forward while negative values travels into the past. +func subCentury(t time.Time) time.Time { + return subCenturies(t, 1) +} + +// subCenturies removes centuries to the current time. +func subCenturies(t time.Time, c int) time.Time { + return addCenturies(t, -c) +} + +// ----------------------------------------------------------------------------- + +// isYesterday determines if the current time is yesterday. +func isYesterday(t time.Time) bool { + n := addDay(Now()) + return isSameDay(n) +} + +// isTomorrow determines if the current time is tomorrow. +func isTomorrow(t time.Time) bool { + n := addDay(Now()) + return isSameDay(n) +} + +// isSunday checks if this day is a Sunday. +func isSunday(t time.Time) bool { + return t.Weekday() == time.Sunday +} + +// isMonday checks if this day is a Monday. +func isMonday(t time.Time) bool { + return t.Weekday() == time.Monday +} + +// isTuesday checks if this day is a Tuesday. +func isTuesday(t time.Time) bool { + return t.Weekday() == time.Tuesday +} + +// isWednesday checks if this day is a Wednesday. +func isWednesday(t time.Time) bool { + return t.Weekday() == time.Wednesday +} + +// isThursday checks if this day is a Thursday. +func isThursday(t time.Time) bool { + return t.Weekday() == time.Thursday +} + +// isFriday checks if this day is a Friday. +func isFriday(t time.Time) bool { + return t.Weekday() == time.Friday +} + +// isSaturday checks if this day is a Saturday. +func isSaturday(t time.Time) bool { + return t.Weekday() == time.Saturday +} + +// weekOfMonth returns the week of the month. +func weekOfMonth(t time.Time) int { + w := math.Ceil(float64(t.Day() / daysPerWeek)) + return int(w + 1) +} + +// weekOfYear returns the week of the current year (alias for time.ISOWeek). +func weekOfYear(t time.Time) (int, int) { + return t.ISOWeek() +} + +// isLongYear determines if the instance is a long year. +func isLongYear(t time.Time) bool { + n := create(t.Year(), time.December, 31, 0, 0, 0, 0, t.Location()) + _, w := weekOfYear(n) + return w == weeksPerLongYear +} + +// isLastWeek returns true is the date is within last week. +func isLastWeek(t time.Time) bool { + secondsInWeek := float64(secondsInWeek) + diff := Now().Sub(t) + if diff.Seconds() > 0 && diff.Seconds() < secondsInWeek { + return true + } + return false +} + +// isLastMonth returns true is the date is within last month. +func isLastMonth(t time.Time) bool { + now := Now() + monthDiff := now.Month() - t.Month() + if absValue(true, int64(monthDiff)) != 1 { + return false + } + if now.UnixNano() > t.UnixNano() && monthDiff == 1 { + return true + } + return false +} + +// previousMonthLastDay returns the last day of the previous month. +func previousMonthLastDay(t time.Time) time.Time { + return t.AddDate(0, 0, -t.Day()) +} + +func diffInYears(t time.Time) { + +} + +// Comparision + +// Eq, EqualTo, Ne, NotEqualTo, Gt, GreaterThan, Gte, GreaterThanOrEqualTo, Lt, LessThan, Lte, LessThanOrEqualTo, Between, Closest, Farthest, + +// DiffInMonths +// DiffDurationInString +// DiffInWeeks +// DiffInDays +// DiffInNights +// DiffInSeconds +// DiffInMinutes +// DiffInHours +// SecondsSinceMidnight +// SecondsUntilEndOfDay +// swap +// StartOfDay returns the time at 00:00:00 of the same day +// EndOfDay returns the time at 23:59:59 of the same day +// StartOfMonth +// EndOfMonth +// StartOfQuarter +// EndOfQuarter +// StartOfYear +// EndOfYear +// StartOfDecade +// EndOfDecade +// StartOfCentury +// EndOfCentury +// StartOfWeek +// EndOfWeek +// Next +// NextWeekday +// PreviousWeekday +// NextWeekendDay +// PreviousWeekendDay +// Previous +// FirstOfMonth +// LastOfMonth +// LastDayOfMonth +// FirstDayOfMonth +// NthOfMonth +// FirstOfQuarter +// LastOfQuarter +// NthOfQuarter +// FirstOfYear +// LastOfYear +// NthOfYear +// Average + +// set everything above. SetDay, SetHour, SetMinute, SetSecond + +// isLeapYear determines if current current time is a leap year. +func isLeap(year int) bool { + return year%4 == 0 && (year%100 != 0 || year%400 == 0) +} + +// daysInYear returns the number of days in the year. +func daysInYear(year int) int { + if isLeap(year) { + return daysInLeapYear + } + return daysInNormalYear +} + +// endOfMonth returns the date at the end of the month and time at 23:59:59 +func endOfMonth(year int, month time.Month) time.Time { + return create(year, month+1, 0, 23, 59, 59, maxNSecs, time.Local) +} + +// copy should return similar time instance as provided one. +func copy(t time.Time) time.Time { + return create( + t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), + t.Nanosecond(), t.Location(), + ) +} + +// quarter should return the current quarter. +func quarter(month time.Month) int { + switch { + case month < 4: + return 1 + case month >= 4 && month < 7: + return 2 + case month >= 7 && month < 10: + return 3 + } + return 4 +} + +// loadLocation should return location instance based on given name. +func loadLocation(name string) (time.Location, error) { + var ( + loc *time.Location + err error + ) + + name = strings.TrimSpace(name) + if name == "" { + loc, err = time.LoadLocation(time.UTC.String()) + } else { + loc, err = time.LoadLocation(name) + } + if err != nil { + return time.Location{}, err + } + + return *loc, nil +} + +// inLocation should return given time in given location with error (if any). +func inLocation(t time.Time, l string) (time.Time, error) { + loc, err := loadLocation(l) + if err != nil { + return time.Time{}, err + } + return t.In(&loc), nil +} + +// lastDayOfMonth should return last numeric day of the given month and year. +// +// Example: +// - timex.LastDayOfMonth(2018, 2) should return 28. +func lastDayOfMonth(year int, month time.Month) int { + // Special case `February` month. + if month == time.February { + if isLeap(year) { + return 29 + } + return 28 + } + + if month <= 7 { + month++ + } + + if month&0x0001 == 0 { + return 31 + } + + return 30 +} + +// ----------------------------------------------------------------------------- + +// Now returns the current local time. +func Now() time.Time { + return now() +} + +// EndOfMonth should return time value at end of the given month and year. +func EndOfMonth(year int, month time.Month) time.Time { + return endOfMonth(year, month) +} + +// Copy should return similar time instance as provided one. +func Copy(t time.Time) time.Time { + return copy(t) +} + +// Quarter should return the current quarter. +func Quarter(month time.Month) int { + return quarter(month) +} + +// LoadLocation should return location instance based on given name. +func LoadLocation(name string) (time.Location, error) { + return loadLocation(name) +} + +// IsPast determines if the current time is in the past, i.e. less (before) +// than Now(). +func IsPast(t time.Time) bool { + return isPast(t) +} + +// IsFuture determines if the current time is in the future, i.e. greater (after) +// than Now(). +func IsFuture(t time.Time) bool { + return isFuture(t) +} + +// IsLeap determines if current current time is a leap year. +func IsLeap(year int) bool { + return isLeap(year) +} + +// Create returns returns a time.Time from a specific date and time. +// If the location is invalid, it returns an error instead. +func Create(y int, mon time.Month, d, h, m, s, ns int, location string) (time.Time, error) { + l, err := loadLocation(location) + if err != nil { + return time.Time{}, err + } + return create(y, mon, d, h, m, s, ns, &l), nil +} + +// CreateFromDate returns a time.Time from a date. +// The time portion is set to time.Now(). +// If the location is invalid, it returns an error instead. +func CreateFromDate(y int, mon time.Month, d int, location string) (time.Time, error) { + now := now() + h, m, s := now.Clock() + return Create(y, mon, d, h, m, s, now.Nanosecond(), location) +} + +// UnixMilli returns the number of milliseconds elapsed since January 1, 1970 +// UTC. +func UnixMilli() int64 { + return now().UnixNano() / time.Millisecond.Nanoseconds() +} + +// NowIn should return time.Time in provided location. +func NowIn(loc *time.Location) time.Time { + return nowIn(loc) +} + +// NowInLocation returns a current time in given location. +// The location is in IANA Time Zone database, such as "America/New_York". +func NowInLocation(l string) (time.Time, error) { + return InLocation(now(), l) +} + +// InLocation should return given time in given location with error (if any). +func InLocation(t time.Time, l string) (time.Time, error) { + return inLocation(t, l) +} + +// InLocationFormat should return given time in given location with error (if +// any). With given format layout. +func InLocationFormat(t time.Time, locStr, layout string) (string, error) { + t, err := inLocation(t, locStr) + if err != nil { + return "", err + } + return t.Format(layout), nil +} + +// DaysBetween returns the number of whole days between the start date and the +// end date. +func DaysBetween(fromDate, toDate time.Time) int { + return int(toDate.Sub(fromDate) / (24 * time.Hour)) +} + +// LastDayOfMonth should return last numeric day of the given month and year. +// +// Example: +// - timex.LastDayOfMonth(2018, 2) should return 28. +func LastDayOfMonth(year int, month time.Month) int { + return lastDayOfMonth(year, month) +} + +// HoursDiff should return hours' difference between given two dates. +func HoursDiff(t, u time.Time) float64 { + return t.Sub(u).Hours() +} + +// MinutesDiff should return minutes' difference between given two dates. +func MinutesDiff(t, u time.Time) float64 { + return t.Sub(u).Minutes() +} + +// SecondsDiff should return seconds' difference between given two dates. +func SecondsDiff(t, u time.Time) float64 { + return t.Sub(u).Seconds() +} + +// NanosecondsDiff should return nanoseconds' difference between given two dates. +func NanosecondsDiff(t, u time.Time) int64 { + return t.Sub(u).Nanoseconds() +} + +// GetOneDayBeginOfTime returns the begin of the time t. +func GetOneDayBeginOfTime(t time.Time) time.Time { + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) +} + +// GetOneDayEndOfTime returns the end of the time t. +func GetOneDayEndOfTime(t time.Time) time.Time { + return GetOneDayBeginOfTime(t).Add(24 * time.Hour).Add(-1 * time.Nanosecond) +} + +// TimeBeginningOfWeek return the begin of the week of time t. +// sundayFirst is used to set week day. As in some countries uses `Monday` as +// the first day of the week. +func TimeBeginningOfWeek(t time.Time, sundayFirst bool) time.Time { + weekday := int(t.Weekday()) + if !sundayFirst { + if weekday == 0 { + weekday = 7 + } + weekday = weekday - 1 + } + + d := time.Duration(-weekday) * 24 * time.Hour + t = t.Add(d) + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location()) +} + +// TimeEndOfWeek return the end of the week of time t. +// sundayFirst is used to set week day. As in some countries uses `Monday` as +// the first day of the week. +func TimeEndOfWeek(t time.Time, sundayFirst bool) time.Time { + return TimeBeginningOfWeek(t, sundayFirst).AddDate(0, 0, 7).Add(-time.Nanosecond) +} + +// TimeBeginningOfMonth return the begin of the month of time t. +func TimeBeginningOfMonth(t time.Time) time.Time { + year, month, _ := t.Date() + return time.Date(year, month, 1, 0, 0, 0, 0, t.Location()) +} + +// TimeEndOfMonth return the end of the month of time t. +func TimeEndOfMonth(t time.Time) time.Time { + return TimeBeginningOfMonth(t).AddDate(0, 1, -1) +} + +// TimeSubDaysOfTwoDays should return the days bewteen time.Time d1 and +// time.Time d2. +func TimeSubDaysOfTwoDays(d1 time.Time, d2 time.Time) int64 { + ds1 := GetOneDayBeginOfTime(d1) + ds2 := GetOneDayBeginOfTime(d2) + return int64(ds1.Sub(ds2).Hours() / 24) +} + +// Format ---------------------------------------------------------------------- + +// ANSICFormat should return given time in ANSIC format. +// +// Example: "Mon Jan _2 15:04:05 2006" +func ANSICFormat(t time.Time) string { + return t.Format(time.ANSIC) +} + +// UnixDateFormat should return given time in UnixDate format. +// +// Example: "Mon Jan _2 15:04:05 MST 2006" +func UnixDateFormat(t time.Time) string { + return t.Format(time.UnixDate) +} + +// RubyDateFormat should return given time in RubyDate format. +// +// Example: "Mon Jan 02 15:04:05 -0700 2006" +func RubyDateFormat(t time.Time) string { + return t.Format(time.RubyDate) +} + +// RFC822Format should return given time in RFC3339 format. +// +// Example: "02 Jan 06 15:04 MST" +func RFC822Format(t time.Time) string { + return t.Format(time.RFC3339) +} + +// RFC822ZFormat should return given time in RFC822Z (RFC822 with numeric +// zone) format. +// +// Example: "02 Jan 06 15:04 -0700" +func RFC822ZFormat(t time.Time) string { + return t.Format(time.RFC822Z) +} + +// RFC850Format should return given time in RFC850 format. +// +// Example: "Monday, 02-Jan-06 15:04:05 MST" +func RFC850Format(t time.Time) string { + return t.Format(time.RFC850) +} + +// RFC1123Format should return given time in RFC1123 format. +// +// Example: "Mon, 02 Jan 2006 15:04:05 MST" +func RFC1123Format(t time.Time) string { + return t.Format(time.RFC1123) +} + +// RFC1123ZFormat should return given time in RFC1123Z (RFC1123 with numeric +// zone) format. +// +// Example: "Mon, 02 Jan 2006 15:04:05 -0700" +func RFC1123ZFormat(t time.Time) string { + return t.Format(time.RFC1123Z) +} + +// RFC3339Format should return given time in RFC3339 format. +// +// Example: "2006-01-02T15:04:05Z07:00" +func RFC3339Format(t time.Time) string { + return t.Format(time.RFC3339) +} + +// RFC3339NanoFormat should return given time in RFC3339Nano format. +// +// Example: "2006-01-02T15:04:05.999999999Z07:00" +func RFC3339NanoFormat(t time.Time) string { + return t.Format(time.RFC3339Nano) +} + +// KitchenFormat should return given time in Kitchen format. +// +// Example: "3:04PM" +func KitchenFormat(t time.Time) string { + return t.Format(time.Kitchen) +} + +// Handy time stamps. + +// StampFormat should return given time in Stamp format. +// +// Example: "Jan _2 15:04:05" +func StampFormat(t time.Time) string { + return t.Format(time.Stamp) +} + +// StampMilliFormat should return given time in StampMilli format. +// +// Example: "Jan _2 15:04:05.000" +func StampMilliFormat(t time.Time) string { + return t.Format(time.StampMilli) +} + +// StampMicroFormat should return given time in StampMicro format. +// +// Example: "Jan _2 15:04:05.000000" +func StampMicroFormat(t time.Time) string { + return t.Format(time.StampMicro) +} + +// StampNanoFormat should return given time in StampNano format. +// +// Example: "Jan _2 15:04:05.000000000" +func StampNanoFormat(t time.Time) string { + return t.Format(time.StampNano) +} + +// ----------------------------------------------------------------------------- diff --git a/tools/cmdx/cmdx.go b/tools/cmdx/cmdx.go new file mode 100644 index 0000000..c9bb618 --- /dev/null +++ b/tools/cmdx/cmdx.go @@ -0,0 +1,161 @@ +package cmdx + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "path/filepath" + "time" +) + +// Log is discarded by default +var Log = func(...interface{}) error { return nil } + +// Loffice executable name +var Loffice = "loffice" + +// Timeout of the child, in seconds +var Timeout = 300 + +// pipeCommands should return a pipe that will be connected to the command's +// standard output when the command starts. +func pipeCommands(commands ...*exec.Cmd) ([]byte, error) { + for i, command := range commands[:len(commands)-1] { + out, err := command.StdoutPipe() + if err != nil { + return nil, err + } + command.Start() + commands[i+1].Stdin = out + } + + final, err := commands[len(commands)-1].Output() + if err != nil { + return nil, err + } + + return final, nil +} + +// OutStr should run given command with provided arguments; returns printed +// output with error (if any). +func OutStr(name string, arg ...string) (outStr string, err error) { + var out bytes.Buffer + cmd := exec.Command(name, arg...) + cmd.Stdout = &out + err = cmd.Run() + outStr = out.String() + return +} + +// OutBytes should run given command with provided arguments; returns printed +// output (as bytes) with error (if any). +func OutBytes(name string, arg ...string) (byteContainer []byte, err error) { + var out bytes.Buffer + cmd := exec.Command(name, arg...) + cmd.Stdout = &out + err = cmd.Run() + byteContainer = out.Bytes() + return +} + +// RunWithTimeout should run cmd with given timeout duration. +func RunWithTimeout(cmd *exec.Cmd, timeout time.Duration) (bool, error) { + var err error + done := make(chan error) + go func() { + done <- cmd.Wait() + }() + + select { + case <-time.After(timeout): // Timeout. + if err = cmd.Process.Kill(); err != nil { + log.Printf("failed to kill: %s, error: %s", cmd.Path, err) // ERROR + } + go func() { + <-done // Allow `goroutine` to exit. + }() + log.Printf("process:%s killed", cmd.Path) // INFO + return true, err + case err = <-done: + return false, err + } +} + +// Convert converts from srcFn to dstFn, into the given format. +// Convert from srcFn to dstFn files, with the given format. +// Either filenames can be empty or "-" which treated as stdin/stdout +func ConvertLoffice(srcFn, dstFn, format string) error { + tempDir, err := ioutil.TempDir("", filepath.Base(srcFn)) + if err != nil { + return fmt.Errorf("cannot create temporary directory: %s", err) + } + defer os.RemoveAll(tempDir) + + if srcFn == "-" || srcFn == "" { + srcFn = filepath.Join(tempDir, "source") + fh, cErr := os.Create(srcFn) + if cErr != nil { + return fmt.Errorf("error creating temp file %q: %s", srcFn, cErr) + } + if _, err = io.Copy(fh, os.Stdin); err != nil { + fh.Close() + return fmt.Errorf("error writing stdout to %q: %s", srcFn, err) + } + fh.Close() + } + + c := exec.Command(Loffice, "--nolockcheck", "--norestore", "--headless", + "--convert-to", format, "--outdir", tempDir, srcFn) + c.Stderr = os.Stderr + c.Stdout = c.Stderr + + Log("msg", "calling", "args", c.Args) + if err = proc.RunWithTimeout(Timeout, c); err != nil { + return fmt.Errorf("error running %q: %s", c.Args, err) + } + + dh, err := os.Open(tempDir) + if err != nil { + return fmt.Errorf("error opening dest dir %q: %s", tempDir, err) + } + defer dh.Close() + + names, err := dh.Readdirnames(3) + if err != nil { + return fmt.Errorf("error listing %q: %s", tempDir, err) + } + if len(names) > 2 { + return fmt.Errorf("too many files in %q: %q", tempDir, names) + } + + var tfn string + for _, fn := range names { + if fn != "source" { + tfn = filepath.Join(dh.Name(), fn) + break + } + } + + src, err := os.Open(tfn) + if err != nil { + return fmt.Errorf("cannot open %q: %s", tfn, err) + } + defer src.Close() + + var dst = io.WriteCloser(os.Stdout) + if !(dstFn == "-" || dstFn == "") { + if dst, err = os.Create(dstFn); err != nil { + return fmt.Errorf("cannot create dest file %q: %s", dstFn, err) + } + } + if _, err = io.Copy(dst, src); err != nil { + return fmt.Errorf("error copying from %v to %v: %v", src, dst, err) + } + + return nil +} diff --git a/utilsx/constants.go b/utilsx/constants.go new file mode 100644 index 0000000..f0f4039 --- /dev/null +++ b/utilsx/constants.go @@ -0,0 +1,16 @@ +package utilsx + +import "strings" + +var ( + // QuoteEscaper should escape quotes from applied string. + QuoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") +) + +const ( + // minURLRuneCount represents threshold URL rune count (min threshold value). + minURLRuneCount = 3 + + // maxURLRuneCount represents threshold URL rune count (max threshold value). + maxURLRuneCount = 2083 +) diff --git a/utilsx/utilsx.go b/utilsx/utilsx.go new file mode 100644 index 0000000..bfc71e0 --- /dev/null +++ b/utilsx/utilsx.go @@ -0,0 +1,400 @@ +// Package utilsx provides tools and generic functions to use as a util to make +// use if go even easier. +package utilsx + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "encoding/json" + "fmt" + "hash/crc32" + "hash/fnv" + "io" + "log" + "math/rand" + "mime/multipart" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/aljiwala/gutils/strx" +) + +var ( + hostname string + getHostname sync.Once +) + +// WordCount returns counts of each word from given string. +// +// Example: +// - utilsx.WordCount("Australia Canada Germany Australia Japan Canada") should +// return map[Japan:1 Australia:2 Canada:2 Germany:1]. +func WordCount(str string) map[string]int { + return wordCount(str) +} + +// EscapeQuotes should escape quotes from given string str. +func EscapeQuotes(str string) string { + return escapeQuotes(str) +} + +// IsURL should check if the string is an URL. +func IsURL(str string) bool { + return isURL(str) +} + +// GetLastElemOfURL should return last segment/element from given URL string. +func GetLastElemOfURL(urlStr string) (string, error) { + return getLastElemOfURL(urlStr) +} + +// MakeHash should convert/make hash of given string. +func MakeHash(str string) (hash string) { + return makeHashCRC32(str) +} + +// HashString ... +func HashString(encoded string) uint64 { + return hashString(encoded) +} + +// MakeUnique should return eigenvalues' string. +func MakeUnique(obj interface{}) string { + return makeUnique(obj) +} + +// MakeMd5 should return encoded string using MD5 checksum. +func MakeMd5(obj interface{}, length int) string { + return makeMD5(obj, length) +} + +// DayOrdinalSuffix should return suffix based on `day` value provided. +func DayOrdinalSuffix(day int) string { + return dayOrdinalSuffix(day) +} + +// GetTimeByZone should return time.Time value as per provided timezone. +// It will return nil time value and error if `In` method panics, otherwise +// original values. +func GetTimeByZone(t time.Time, zoneStr string) (time.Time, error) { + return getTimeByZone(t, zoneStr) +} + +// MakeMsgID creates a new, globally unique message ID, useable as a Message-ID +// as per RFC822/RFC2822. +func MakeMsgID() string { + return makeMsgID() +} + +// BuildFileRequest should make a request with `Content-Type` as a +// multipart-formdata. +func BuildFileRequest( + urlStr, fieldname, path string, params, headers map[string]string) ( + *http.Request, error) { + return buildFileRequest(urlStr, fieldname, path, params, headers) +} + +// GetEnvWithDefault should return the value of $env from the OS +// and if it's empty, returns default one. +func GetEnvWithDefault(env, def string) (value string) { + return getEnvWithDefault(env, def) +} + +// GetEnvWithDefaultInt return the int value of $env from the OS and +// if it's empty, returns def. +func GetEnvWithDefaultInt(env string, def int) (int, error) { + return getEnvWithDefaultInt(env, def) +} + +// GetEnvWithDefaultBool should return the bool value of $env from the OS +// and if it's empty, returns def. +func GetEnvWithDefaultBool(env string, def bool) (bool, error) { + return getEnvWithDefaultBool(env, def) +} + +// GetEnvWithDefaultDuration return the time duration value of $env from the OS and +// if it's empty, returns def. +func GetEnvWithDefaultDuration(env, def string) (time.Duration, error) { + return getEnvWithDefaultDuration(env, def) +} + +// GetEnvWithDefaultStrings should return a slice of sorted strings from +// the environment or default split on, So "foo,bar" returns ["bar","foo"]. +func GetEnvWithDefaultStrings(env, def string) (v []string) { + return getEnvWithDefaultStrings(env, def) +} + +//////////////////////////////////////////////////////////////////////////////// + +func wordCount(str string) map[string]int { + counts := make(map[string]int) + wordList := strings.Fields(str) + for _, word := range wordList { + _, ok := counts[word] + if ok { + counts[word]++ + } else { + counts[word] = 1 + } + } + return counts +} + +func escapeQuotes(str string) string { + return QuoteEscaper.Replace(str) +} + +func isURL(str string) bool { + str = strings.TrimSpace(str) + if str == "" || strings.HasPrefix(str, ".") || + len(str) <= minURLRuneCount || + utf8.RuneCountInString(str) >= maxURLRuneCount { + return false + } + + u, err := url.Parse(str) + if err != nil { + return false + } + + if strings.HasPrefix(u.Host, ".") { + return false + } + + if u.Host == "" && (u.Path != "" && !strings.Contains(u.Path, ".")) { + return false + } + + return strx.RXURL.MatchString(str) +} + +func getLastElemOfURL(urlStr string) (string, error) { + parsedURL, err := url.Parse(urlStr) + if err != nil { + return "", err + } + return path.Base(parsedURL.Path), nil +} + +func makeHashCRC32(str string) (hash string) { + const IEEE = 0xedb88320 + var IEEETable = crc32.MakeTable(IEEE) + hash = fmt.Sprintf("%x", crc32.Checksum([]byte(str), IEEETable)) + return +} + +func hashString(encoded string) uint64 { + hash := fnv.New64() + hash.Write([]byte(encoded)) + return hash.Sum64() +} + +func makeUnique(obj interface{}) string { + baseString, _ := json.Marshal(obj) + return strconv.FormatUint(hashString(string(baseString)), 10) +} + +func makeMD5(obj interface{}, length int) string { + if length > 32 { + length = 32 + } + h := md5.New() + baseString, _ := json.Marshal(obj) + h.Write([]byte(baseString)) + s := hex.EncodeToString(h.Sum(nil)) + return s[:length] +} + +func dayOrdinalSuffix(day int) string { + switch day % 10 { + case 1: + return "st" + case 2: + return "nd" + case 3: + return "rd" + default: + return "th" + } +} + +func getTimeByZone(t time.Time, zoneStr string) (time.Time, error) { + loc, err := time.LoadLocation(zoneStr) + if err != nil { + return time.Time{}, err + } + return t.In(loc), nil +} + +func makeMsgID() string { + getHostname.Do(func() { + var err error + if hostname, err = os.Hostname(); err != nil { + log.Printf("ERROR: Get hostname: %v", err) + hostname = "localhost" + } + }) + + now := time.Now() + return fmt.Sprintf( + "<%d.%d.%d@%s>", now.Unix(), now.UnixNano(), rand.Int63(), hostname, + ) +} + +func buildFileRequest( + urlStr, fieldname, path string, params, headers map[string]string) ( + *http.Request, error) { + // Opens the named file for reading. + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile(fieldname, filepath.Base(path)) + if err != nil { + return nil, err + } + + // Write to destination file from source file. + if _, cpErr := io.Copy(part, file); cpErr != nil { + return nil, cpErr + } + + // Range over given params and write the given value. + for key, val := range params { + _ = writer.WriteField(key, val) + } + + // Write the trailing boundary to end line to output and close it. + if cErr := writer.Close(); cErr != nil { + return nil, cErr + } + + req, err := http.NewRequest(http.MethodPost, urlStr, body) + req.Header.Add("Content-Type", writer.FormDataContentType()) + // Set provided header values. + for key, value := range headers { + req.Header.Set(key, value) + } + + return req, err +} + +func getEnvWithDefault(env, def string) string { + v := os.Getenv(env) + if v == "" { + return def + } + + return v +} + +func getEnvWithDefaultInt(env string, def int) (int, error) { + v := os.Getenv(env) + if v == "" { + return def, nil + } + + return strconv.Atoi(v) +} + +func getEnvWithDefaultBool(env string, def bool) (bool, error) { + v := os.Getenv(env) + if v == "" { + return def, nil + } + + return strconv.ParseBool(v) +} + +func getEnvWithDefaultDuration(env, def string) (time.Duration, error) { + v := os.Getenv(env) + if v == "" { + v = def + } + + return time.ParseDuration(v) +} + +func getEnvWithDefaultStrings(env, def string) (v []string) { + env = GetEnvWithDefault(env, def) + if env == "" { + return make([]string, 0) + } + + v = strings.Split(env, ",") + if !sort.StringsAreSorted(v) { + sort.Strings(v) + } + + return v +} + +// // SaveToFile should copy the contents of given file to destination file. +// func SaveToFile(r *http.Request, fromFile string, path string) error { +// file, _, err := r.FormFile(fromFile) +// if err != nil { +// return err +// } +// defer file.Close() +// +// f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) +// if err != nil { +// return err +// } +// defer f.Close() +// +// _, cErr := io.Copy(f, file) +// if cErr != nil { +// return cErr +// } +// +// return nil +// } + +// // GetContentFromURL should get the response from specified URL. +// func GetContentFromURL(url string) (*http.Response, error) { +// response, err := http.Get(url) +// if err != nil { +// return nil, err +// } +// +// return response, nil +// } + +//////////////////////////////////////////////////////////////////////////////// + +// TimeTrack will print the execution time of the function. +// Possible Usage(s): +// - Call `TimeTrack` function using defer statement. +// +// Ref: https://stackoverflow.com/a/45773638/4039768 +func TimeTrack(start time.Time) { + // Skip this function, and fetch the PC and file for its parent. + pc, _, _, _ := runtime.Caller(1) + // Retrieve a function object this functions parent. + funcObj := runtime.FuncForPC(pc) + + // Regex to extract just the function name (and not the module path). + runtimeFunc := regexp.MustCompile(`^.*\.(.*)$`) + funcName := runtimeFunc.ReplaceAllString(funcObj.Name(), "$1") + + log.Printf("%s took %s", funcName, time.Since(start)) +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/utilsx/utilsx_test.go b/utilsx/utilsx_test.go new file mode 100644 index 0000000..61a0daa --- /dev/null +++ b/utilsx/utilsx_test.go @@ -0,0 +1,32 @@ +package utilsx + +import ( + "fmt" + "testing" +) + +func TestWordCount(t *testing.T) { + case1Want := make(map[string]int, 4) + case1Want["Australia"] = 2 + case1Want["Canada"] = 2 + case1Want["Germany"] = 1 + case1Want["Japan"] = 1 + + cases := []struct { + in string + want map[string]int + }{ + {"Australia Canada Germany Australia Japan Canada", case1Want}, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { + out := WordCount(tc.in) + for i, ii := range out { + if case1Want[i] != ii { + t.Errorf("\nout: %#v\nwant: %#v\n", ii, case1Want[i]) + } + } + }) + } +}