Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚀 performance: add concurrency to all search/find functionality, lazy imports, etc #217

Merged
merged 23 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `obsidian.async` module for internal use.

### Changed

- Re-implemented the native Lua YAML parser (`obsidian.yaml.native`). This should be faster and more robust now. 🤠
- Re-implemented search/find functionality to utilize concurrency via `obsidian.async` and `plenary.async` for big performance gains. 🏎️
- Submodules imported lazily.
- Changes to internal module organization.

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion RELEASE_PROCESS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Steps

1. Update the `VERSION` field in `lua/obsidian/init.lua`.
1. Update the version in `lua/obsidian/version.lua`.

3. Run the release script:

Expand Down
90 changes: 48 additions & 42 deletions lua/cmp_obsidian.lua
Original file line number Diff line number Diff line change
Expand Up @@ -19,58 +19,64 @@ source.complete = function(self, request, callback)
local can_complete, search, insert_start, insert_end = completion.can_complete(request)

if can_complete and search ~= nil and #search >= opts.completion.min_chars then
local items = {}
for note in client:search(search, "--ignore-case") do
local aliases = util.unique { tostring(note.id), note:display_name(), unpack(note.aliases) }
for _, alias in pairs(aliases) do
local options = {}
local function search_callback(results)
local items = {}
for _, result in ipairs(results) do
local note = result[1]
local aliases = util.unique { tostring(note.id), note:display_name(), unpack(note.aliases) }
for _, alias in pairs(aliases) do
local options = {}

local alias_case_matched = util.match_case(search, alias)
if
alias_case_matched ~= nil
and alias_case_matched ~= alias
and not util.contains(note.aliases, alias_case_matched)
then
table.insert(options, alias_case_matched)
end
local alias_case_matched = util.match_case(search, alias)
if
alias_case_matched ~= nil
and alias_case_matched ~= alias
and not util.contains(note.aliases, alias_case_matched)
then
table.insert(options, alias_case_matched)
end

table.insert(options, alias)
table.insert(options, alias)

for _, option in pairs(options) do
local label = "[[" .. tostring(note.id)
if option ~= tostring(note.id) then
label = label .. "|" .. option .. "]]"
else
label = label .. "]]"
end
for _, option in pairs(options) do
local label = "[[" .. tostring(note.id)
if option ~= tostring(note.id) then
label = label .. "|" .. option .. "]]"
else
label = label .. "]]"
end

table.insert(items, {
sortText = "[[" .. option,
label = label,
kind = 18,
textEdit = {
newText = label,
range = {
start = {
line = request.context.cursor.row - 1,
character = insert_start,
},
["end"] = {
line = request.context.cursor.row - 1,
character = insert_end,
table.insert(items, {
sortText = "[[" .. option,
label = label,
kind = 18,
textEdit = {
newText = label,
range = {
start = {
line = request.context.cursor.row - 1,
character = insert_start,
},
["end"] = {
line = request.context.cursor.row - 1,
character = insert_end,
},
},
},
},
})
})
end
end
end

callback {
items = items,
isIncomplete = false,
}
end
return callback {
items = items,
isIncomplete = false,
}

client:search_async(search, { "--ignore-case" }, search_callback)
else
return callback { isIncomplete = true }
callback { isIncomplete = true }
end
end

Expand Down
252 changes: 252 additions & 0 deletions lua/obsidian/async.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
local async = require "plenary.async"
local channel = require("plenary.async.control").channel
local uv = vim.loop

local M = {}

---An abstract class that mimics Python's `concurrent.futures.Executor` class.
---@class obsidian.Executor
---@field tasks_running integer
local Executor = {}

---@return obsidian.Executor
Executor.new = function()
local self = setmetatable({}, { __index = Executor })
self.tasks_running = 0
return self
end

---Submit a one-off function with a callback to the thread pool.
---
---@param self obsidian.Executor
---@param fn function
---@param callback function|?
---@diagnostic disable-next-line: unused-local,unused-vararg
Executor.submit = function(self, fn, callback, ...)
error "not implemented"
end

---Map a function over a generator or array of task args. The callback is called with an array of the results
---once all tasks have finished. The order of the results passed to the callback will be the same
---as the order of the corresponding task args.
---
---@param self obsidian.Executor
---@param fn function
---@param task_args table[]|function
---@param callback function|?
---@diagnostic disable-next-line: unused-local
Executor.map = function(self, fn, task_args, callback)
local results = {}
local num_tasks = 0
local tasks_completed = 0
local all_submitted = false
local tx, rx = channel.oneshot()

local function collect_results()
rx()
return results
end

local function get_task_done_fn(i)
return function(...)
tasks_completed = tasks_completed + 1
results[i] = { ... }
if all_submitted and tasks_completed == num_tasks then
tx()
end
end
end

if type(task_args) == "table" then
num_tasks = #task_args

for i, args in ipairs(task_args) do
self:submit(fn, get_task_done_fn(i), unpack(args))
end
elseif type(task_args) == "function" then
local i = 0
local args = { task_args() }
local next_args = { task_args() }
while args[1] ~= nil do
if next_args[1] == nil then
all_submitted = true
end
i = i + 1
num_tasks = num_tasks + 1
self:submit(fn, get_task_done_fn(i), unpack(args))
args = next_args
next_args = { task_args() }
end
end

if num_tasks == 0 then
if callback ~= nil then
callback {}
end
else
async.run(collect_results, callback and callback or function(_) end)
end
end

---@param self obsidian.Executor
---@param timeout integer|?
---@param pause_fn function(integer)
Executor._join = function(self, timeout, pause_fn)
local start_time = uv.uptime()
local pause_for = 100
if timeout ~= nil then
pause_for = math.min(timeout / 2, pause_for)
end
while self.tasks_running > 0 do
pause_fn(pause_for)
if timeout ~= nil and uv.uptime() - start_time > timeout then
error "Timeout error from AsyncExecutor.join()"
end
end
end

---Block Neovim until all currently running tasks have completed, waiting at most `timeout` milliseconds
---before raising a timeout error.
---
---This is useful in testing, but in general you want to avoid blocking Neovim.
---
---@param self obsidian.Executor
---@param timeout integer|?
Executor.join = function(self, timeout)
self:_join(timeout, vim.wait)
end

---An async version of `.join()`.
---
---@param self obsidian.Executor
---@param timeout integer|?
Executor.join_async = function(self, timeout)
self:_join(timeout, async.util.sleep)
end

---An Executor that uses coroutines to run user functions concurrently.
---@class obsidian.AsyncExecutor : obsidian.Executor
---@field tasks_running integer
local AsyncExecutor = Executor.new()
M.AsyncExecutor = AsyncExecutor

---@return obsidian.AsyncExecutor
AsyncExecutor.new = function()
local self = setmetatable({}, { __index = AsyncExecutor })
self.tasks_running = 0
return self
end

---Submit a one-off function with a callback to the thread pool.
---
---@param self obsidian.AsyncExecutor
---@param fn function
---@param callback function|?
---@diagnostic disable-next-line: unused-local
AsyncExecutor.submit = function(self, fn, callback, ...)
self.tasks_running = self.tasks_running + 1
local args = { ... }
async.run(function()
return fn(unpack(args))
end, function(...)
self.tasks_running = self.tasks_running - 1
if callback ~= nil then
callback(...)
end
end)
end

---A multi-threaded Executor which uses the Libuv threadpool.
---@class obsidian.ThreadPoolExecutor : obsidian.Executor
---@field tasks_running integer
local ThreadPoolExecutor = Executor.new()
M.ThreadPoolExecutor = ThreadPoolExecutor

---@return obsidian.ThreadPoolExecutor
ThreadPoolExecutor.new = function()
local self = setmetatable({}, { __index = ThreadPoolExecutor })
self.tasks_running = 0
return self
end

---Submit a one-off function with a callback to the thread pool.
---
---@param self obsidian.ThreadPoolExecutor
---@param fn function
---@param callback function|?
---@diagnostic disable-next-line: unused-local
ThreadPoolExecutor.submit = function(self, fn, callback, ...)
self.tasks_running = self.tasks_running + 1
local ctx = uv.new_work(fn, function(...)
self.tasks_running = self.tasks_running - 1
if callback ~= nil then
callback(...)
end
end)
ctx:queue(...)
end

---Represents a file.
---@class obsidian.File
---@field fd userdata
local File = {}
M.File = File

---@param path string
---@return obsidian.File
File.open = function(path)
local self = setmetatable({}, { __index = File })
local err, fd = async.uv.fs_open(path, "r", 438)
assert(not err, err)
self.fd = fd
return self
end

---Close the file.
---@param self obsidian.File
File.close = function(self)
local err = async.uv.fs_close(self.fd)
assert(not err, err)
end

---Get at iterator over lines in the file.
---@param include_new_line_char boolean|?
File.lines = function(self, include_new_line_char)
local offset = 0
local chunk_size = 1024
local buffer = ""
local eof_reached = false

local lines = function()
local idx = string.find(buffer, "[\r\n]")
while idx == nil and not eof_reached do
---@diagnostic disable-next-line: redefined-local
local err, data
err, data = async.uv.fs_read(self.fd, chunk_size, offset)
assert(not err, err)
if string.len(data) == 0 then
eof_reached = true
else
buffer = buffer .. data
offset = offset + string.len(data)
idx = string.find(buffer, "[\r\n]")
end
end

if idx ~= nil then
local line = string.sub(buffer, 1, idx)
buffer = string.sub(buffer, idx + 1)
if include_new_line_char then
return line
else
return string.sub(line, 1, -2)
end
else
return nil
end
end

return lines
end

return M
Loading