From e1951899211e97d337c6f13941eb07a37784f6ae Mon Sep 17 00:00:00 2001 From: nighca Date: Mon, 20 Jan 2025 16:53:30 +0800 Subject: [PATCH] fix `importName` for multi-import in single file --- import.go | 26 ++++++++++++++++--- package.go | 7 +++--- package_test.go | 67 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 7 deletions(-) diff --git a/import.go b/import.go index bbf1651..0cda439 100644 --- a/import.go +++ b/import.go @@ -583,9 +583,16 @@ func (p *Package) big() PkgRef { // ---------------------------------------------------------------------------- type null struct{} + +type importName struct { + name string + file string +} + type autoNames struct { - names map[string]null - autoIdx int + names map[string]null + importNames map[importName]null + autoIdx int } const ( @@ -594,6 +601,7 @@ const ( func (p *autoNames) init() { p.names = make(map[string]null) + p.importNames = make(map[importName]null) } func (p *autoNames) autoName() string { @@ -610,14 +618,24 @@ func (p *autoNames) hasName(name string) bool { return ok } -func (p *autoNames) importName(name string) (ret string, renamed bool) { +func (p *autoNames) useImportName(file, name string) { + p.importNames[importName{name, file}] = null{} +} + +func (p *autoNames) hasImportName(file, name string) bool { + _, ok := p.importNames[importName{name, file}] + return ok +} + +func (p *autoNames) importName(file, name string) (ret string, renamed bool) { ret = name var idx int - for p.hasName(ret) { + for p.hasName(ret) || p.hasImportName(file, ret) { idx++ ret = name + strconv.Itoa(idx) renamed = true } + p.useImportName(file, ret) return } diff --git a/package.go b/package.go index aaa9a10..7634dd5 100644 --- a/package.go +++ b/package.go @@ -167,7 +167,7 @@ func (p *File) forceImport(pkgPath string) { func (p *File) markUsed(this *Package) { if p.dirty { - astVisitor{this}.markUsed(p.decls) + astVisitor{this, p}.markUsed(p.decls) p.dirty = false } } @@ -178,7 +178,8 @@ func (p *File) Name() string { } type astVisitor struct { - this *Package + pkg *Package + file *File } func (p astVisitor) Visit(node ast.Node) (w ast.Visitor) { @@ -192,7 +193,7 @@ func (p astVisitor) Visit(node ast.Node) (w ast.Visitor) { if id, ok := x.(*ast.Ident); ok && id.Obj != nil { if used, ok := id.Obj.Data.(importUsed); ok && bool(!used) { id.Obj.Data = importUsed(true) - if name, renamed := p.this.importName(id.Name); renamed { + if name, renamed := p.pkg.importName(p.file.Name(), id.Name); renamed { id.Name = name id.Obj.Name = name } diff --git a/package_test.go b/package_test.go index 4e57670..9ddf150 100644 --- a/package_test.go +++ b/package_test.go @@ -2621,6 +2621,73 @@ func demo() { `, "b.go") } +type FmtPatchImporter struct{} + +func (m *FmtPatchImporter) Import(path string) (pkg *types.Package, err error) { + if path == "fmt@patch" { + fmtPkg, _ := gblImp.Import("fmt") + fmtPrintfObj := fmtPkg.Scope().Lookup("Printf") + patchPkg := types.NewPackage(path, "fmt") + patchPrintfObj := types.NewFunc(fmtPrintfObj.Pos(), patchPkg, "Printf", fmtPrintfObj.Type().(*types.Signature)) + patchPkg.Scope().Insert(patchPrintfObj) + return patchPkg, nil + } + return gblImp.Import(path) +} + +// TestPatchImport tests importing behavior for patched package. +// See related issue: https://github.com/goplus/igop/pull/275 +func TestPatchImport(t *testing.T) { + pkg := gogen.NewPackage("", "main", &gogen.Config{ + Fset: gblFset, + Importer: &FmtPatchImporter{}, + Recorder: eventRecorder{}, + NodeInterpreter: nodeInterp{}, + DbgPositioner: nodeInterp{}, + }) + + _, err := pkg.SetCurFile("a.go", true) + if err != nil { + t.Fatal("pkg.SetCurFile failed:", err) + } + fmt := pkg.Import("fmt") + fmt2 := pkg.Import("fmt@patch") + pkg.NewFunc(nil, "main", nil, nil, false).BodyStart(pkg). + Val(fmt.Ref("Println")).Val("Hello").Call(1).EndStmt(). + Val(fmt2.Ref("Printf")).Val("Hello").Call(1).EndStmt(). + End() + + _, err = pkg.SetCurFile("b.go", true) + if err != nil { + t.Fatal("pkg.SetCurFile failed:", err) + } + pkg.NewFunc(nil, "demo", nil, nil, false).BodyStart(pkg). + Val(fmt.Ref("Println")).Val("Hello").Call(1).EndStmt(). + End() + + domTestEx(t, pkg, `package main + +import ( + "fmt" + fmt1 "fmt@patch" +) + +func main() { + fmt.Println("Hello") + fmt1.Printf("Hello") +} +`, "a.go") + + domTestEx(t, pkg, `package main + +import "fmt" + +func demo() { + fmt.Println("Hello") +} +`, "b.go") +} + func TestImportMultiFiles(t *testing.T) { pkg := newMainPackage()