diff --git a/formatter/config.go b/formatter/config.go index 7024e1a..856fc61 100644 --- a/formatter/config.go +++ b/formatter/config.go @@ -5,7 +5,16 @@ type Config struct { MaxLineLength uint8 `mapstructure:"max-line-length"` Highlight bool + Spaces Spaces + Alignment Alignment `mapstructure:"alignment"` +} + +func DefaultConfig() Config { + return Config{ + IndentSize: 4, + MaxLineLength: 80, + } } type Spaces struct { @@ -16,3 +25,26 @@ type Around struct { UnaryOperator bool MultiplicativeOperator bool } + +type Alignment struct { + Table AlignmentTable `mapstructure:"table"` +} + +// AlignmentTable formatting tables in code +type AlignmentTable struct { + // KeyValue = true + // t = { + // key1 = value1, + // key10 = value10, + // key100 = value100, + // } + KeyValuePairs bool `mapstructure:"key-value-pairs"` + + // Comments = true + // t = { + // key1 = value1, -- comment + // key10 = value10, -- comment + // key100 = value100, -- comment + // } + Comments bool `mapstructure:"comments"` +} diff --git a/formatter/document.go b/formatter/document.go index 418ac4f..864ed8e 100644 --- a/formatter/document.go +++ b/formatter/document.go @@ -21,6 +21,8 @@ type document struct { Body statement } +type isBreak bool + type statement interface { Append(*element) AppendStatement(statement) diff --git a/formatter/exp.go b/formatter/exp.go index c642f7e..d718f8d 100644 --- a/formatter/exp.go +++ b/formatter/exp.go @@ -437,6 +437,10 @@ func (s *exp) GetStatement(prev, cur *element) statement { func (s *exp) Format(c *Config, p printer, w io.Writer) error { if s.Comments != nil { for i := 0; i < len(s.Comments); i++ { + if s.Comments[uint64(i)].Token.Type == nComment { + continue + } + if _, err := w.Write([]byte("--[[ ")); err != nil { return err } diff --git a/formatter/field.go b/formatter/field.go index 93b63f6..75a9e28 100644 --- a/formatter/field.go +++ b/formatter/field.go @@ -14,7 +14,9 @@ package formatter -import "io" +import ( + "io" +) type field struct { Key *exp @@ -130,6 +132,12 @@ func (s *field) Format(c *Config, p printer, w io.Writer) error { return nil } + if p.SpacesBeforeAssign > 0 { + if err := p.WriteSpaces(w, int(p.SpacesBeforeAssign)); err != nil { + return err + } + } + if _, err := w.Write([]byte(" = ")); err != nil { return err } diff --git a/formatter/fieldlist.go b/formatter/fieldlist.go index ba3a108..848a43b 100644 --- a/formatter/fieldlist.go +++ b/formatter/fieldlist.go @@ -14,7 +14,10 @@ package formatter -import "io" +import ( + "bytes" + "io" +) type fieldlist struct { List []*field @@ -62,14 +65,55 @@ func (s *fieldlist) GetStatement(prev, cur *element) statement { } func (s *fieldlist) Format(c *Config, p printer, w io.Writer) error { + var fl map[uint64]fieldLength + + t := c.Alignment.Table + if t.KeyValuePairs || t.Comments { + if p.ParentStatement == tsTable { + fl = s.Align(c, p) + } + } + for i, v := range s.List { + if v.Key.Element == nil && + v.Key.Table == nil && + v.Key.Func == nil && + v.Key.Binop == nil && + v.Key.Unop == nil && + v.Key.Exp == nil && + v.Key.Prefixexp == nil { + continue + } + if p.ParentStatement == tsTable { if err := p.WritePad(w); err != nil { return err } } - if err := v.Format(c, p, w); err != nil { + if i == 0 && v.Key.Comments != nil { + for _, com := range v.Key.Comments { + if com.Token.Type != nComment { + break + } + + if _, err := w.Write([]byte("-- ")); err != nil { + return err + } + + if err := com.Format(c, p, w); err != nil { + return err + } + + if err := newLine(w); err != nil { + return err + } + } + } + + fieldPrinter := p + fieldPrinter.SpacesBeforeAssign = fl[uint64(i)].Key + if err := v.Format(c, fieldPrinter, w); err != nil { return err } @@ -86,6 +130,23 @@ func (s *fieldlist) Format(c *Config, p printer, w io.Writer) error { return err } + if i+1 < len(s.List) { + com := s.List[i+1].Key.Comments + if com != nil && len(com) > 0 && com[0].Token.Type == nComment { + if err := p.WriteSpaces(w, int(fl[uint64(i)].Val)); err != nil { + return err + } + + if _, err := w.Write([]byte(" -- ")); err != nil { + return err + } + + if _, err := w.Write(com[0].Token.Lexeme); err != nil { + return err + } + } + } + if err := newLine(w); err != nil { return err } @@ -94,3 +155,78 @@ func (s *fieldlist) Format(c *Config, p printer, w io.Writer) error { return nil } + +type fieldLength struct { + Key uint8 + Val uint8 +} + +func (s *fieldlist) Align(c *Config, p printer) map[uint64]fieldLength { + var ( + MaxKeyLength uint8 + MaxValueLength uint8 + + res = make(map[uint64]fieldLength) + + alignBlock = make(map[uint64]fieldLength) + w = bytes.NewBuffer(nil) + ) + + for i := 0; i < len(s.List); i++ { + item := s.List[i] + + if item.Val != nil && item.Val.Func != nil { + for b, v := range alignBlock { + res[b] = fieldLength{ + Key: MaxKeyLength - v.Key, + Val: MaxValueLength - v.Val, + } + } + + alignBlock = make(map[uint64]fieldLength) + MaxKeyLength = 0 + MaxValueLength = 0 + + continue + } + + if s.List[i].Square { + w.WriteString("[]") + } + + if err := s.List[i].Key.Format(c, p, w); err != nil { + return res + } + + kl := uint8(w.Len()) + w.Reset() + + if s.List[i].Val != nil { + if err := s.List[i].Val.Format(c, p, w); err != nil { + return res + } + } + + vl := uint8(w.Len()) + w.Reset() + + alignBlock[uint64(i)] = fieldLength{Key: kl, Val: vl} + + if MaxKeyLength < kl { + MaxKeyLength = kl + } + + if MaxValueLength < vl { + MaxValueLength = vl + } + } + + for b, v := range alignBlock { + res[b] = fieldLength{ + Key: MaxKeyLength - v.Key, + Val: MaxValueLength - v.Val, + } + } + + return res +} diff --git a/formatter/formatter.go b/formatter/formatter.go index 8f1a750..7f0f772 100644 --- a/formatter/formatter.go +++ b/formatter/formatter.go @@ -49,21 +49,22 @@ func Format(c Config, b []byte, w io.Writer) error { return nil } -func DefaultConfig() Config { - return Config{ - IndentSize: 4, - MaxLineLength: 80, - } -} - type printer struct { ParentStatement typeStatement Pad uint8 - IgnoreFirstPad bool + + SpacesBeforeAssign uint8 + SpacesBeforeComment uint8 + + IgnoreFirstPad bool } func (p printer) WritePad(w io.Writer) error { - b := bytes.Repeat([]byte(" "), int(p.Pad)) + return p.WriteSpaces(w, int(p.Pad)) +} + +func (p printer) WriteSpaces(w io.Writer, count int) error { + b := bytes.Repeat([]byte(" "), count) _, err := w.Write(b) return err diff --git a/formatter/formatter_test.go b/formatter/formatter_test.go index 1bd9148..0cac3e2 100644 --- a/formatter/formatter_test.go +++ b/formatter/formatter_test.go @@ -527,3 +527,73 @@ end }) } } + +func TestFormat_withConfigTable(t *testing.T) { + type args struct { + c Config + b []byte + } + tests := []struct { + name string + args args + wantW string + wantErr bool + }{ + { + name: "assignment statement", + args: args{ + c: Config{ + IndentSize: 4, + MaxLineLength: 80, + Alignment: Alignment{ + Table: AlignmentTable{ + KeyValuePairs: true, + Comments: true, + }, + }, + }, + b: []byte(` +table = { + ["a()"] = false, -- comm 1 + [1+1] = true, -- comm 2 + bb = function () return 1 end, -- comm 3 + ["1394-E"] = val1, -- comm 4 + ["UTF-8"] = val2, -- comm 5 + ["and"] = val3, -- comm 6 + [true] = 1, -- comm 7 + aa = nil, -- comm 8 +} +`), + }, + wantW: ` +table = { + ["a()"] = false, -- comm 1 + [1 + 1] = true, -- comm 2 + bb = function() + return 1 + end, -- comm 3 + ["1394-E"] = val1, -- comm 4 + ["UTF-8"] = val2, -- comm 5 + ["and"] = val3, -- comm 6 + [true] = 1, -- comm 7 + aa = nil, -- comm 8 +} +`, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &bytes.Buffer{} + w.Write([]byte("\n")) + if err := Format(tt.args.c, tt.args.b, w); (err != nil) != tt.wantErr { + t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !assert.Equal(t, tt.wantW, w.String()) { + t.Error("failed to format") + } + }) + } +} diff --git a/formatter/lexer.go b/formatter/lexer.go index c4d7bb3..6d1e2b7 100644 --- a/formatter/lexer.go +++ b/formatter/lexer.go @@ -42,9 +42,9 @@ func newScanner(code []byte) (*scanner, error) { lexer.Add([]byte(`--\[\[([^\]\]])*\]\]`), commentLong(nCommentLong)) lexer.Add([]byte(`::([^::])*::`), token(nLabel)) - lexer.Add([]byte(`(")[^(")]*(")`), token(nString)) - lexer.Add([]byte(`(')[^(')]*(')`), token(nString)) - lexer.Add([]byte(`(\[\[)[^(\]\])]*(\]\])`), token(nString)) + lexer.Add([]byte(`"[^"]*"`), token(nString)) + lexer.Add([]byte(`'[^']*'`), token(nString)) + lexer.Add([]byte(`\[\[[^\]\]]*\]\]`), token(nString)) if err := lexer.Compile(); err != nil { return nil, err