Skip to content

Commit

Permalink
refactor: improve repl completer (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 2, 2023
1 parent 7c68417 commit 652b515
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 105 deletions.
2 changes: 1 addition & 1 deletion src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ macro_rules! register_client {
anyhow::bail!("Unknown client {}", client)
}

pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::client::ModelInfo> {
pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::ModelInfo> {
config
.clients
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/client/model_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl ModelInfo {
self
}

pub fn full_name(&self) -> String {
pub fn id(&self) -> String {
format!("{}:{}", self.client, self.name)
}

Expand Down
72 changes: 35 additions & 37 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use self::role::Role;
use self::session::{Session, TEMP_SESSION_NAME};

use crate::client::{
all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, Message,
create_client_config, list_client_types, list_models, ClientConfig, ExtraConfig, Message,
ModelInfo, OpenAIClient, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
Expand Down Expand Up @@ -35,16 +35,6 @@ const ROLES_FILE_NAME: &str = "roles.yaml";
const MESSAGES_FILE_NAME: &str = "messages.md";
const SESSIONS_DIR_NAME: &str = "sessions";

const SET_COMPLETIONS: [&str; 7] = [
".set temperature",
".set save true",
".set save false",
".set highlight true",
".set highlight false",
".set dry_run true",
".set dry_run false",
];

const CLIENTS_FIELD: &str = "clients";

#[derive(Debug, Clone, Deserialize)]
Expand Down Expand Up @@ -311,10 +301,11 @@ impl Config {
}

pub fn set_model(&mut self, value: &str) -> Result<()> {
let models = all_models(self);
let models = list_models(self);
let mut model_info = None;
let value = value.trim_end_matches(':');
if value.contains(':') {
if let Some(model) = models.iter().find(|v| v.full_name() == value) {
if let Some(model) = models.iter().find(|v| v.id() == value) {
model_info = Some(model.clone());
}
} else if let Some(model) = models.iter().find(|v| v.client == value) {
Expand Down Expand Up @@ -345,7 +336,7 @@ impl Config {
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("model", self.model_info.full_name()),
("model", self.model_info.id()),
("temperature", temperature),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
Expand Down Expand Up @@ -402,22 +393,27 @@ impl Config {
.unwrap_or_default()
}

pub fn repl_completions(&self) -> Vec<String> {
let mut completion: Vec<String> = self
.roles
.iter()
.map(|v| format!(".role {}", v.name))
pub fn repl_complete(&self, cmd: &str, args: &str) -> Vec<String> {
let possible_values = match cmd {
".role" => self.roles.iter().map(|v| v.name.clone()).collect(),
".model" => list_models(self).into_iter().map(|v| v.id()).collect(),
".session" => self.list_sessions(),
".set" => {
vec![
"temperature ".into(),
format!("save {}", !self.save),
format!("highlight {}", !self.highlight),
format!("dry_run {}", !self.dry_run),
]
}
_ => vec![],
};
let mut possible_values: Vec<String> = possible_values
.into_iter()
.filter(|v| v.starts_with(args))
.collect();

completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string));
completion.extend(
all_models(self)
.iter()
.map(|v| format!(".model {}", v.full_name())),
);
let sessions = self.list_sessions().unwrap_or_default();
completion.extend(sessions.iter().map(|v| format!(".session {}", v)));
completion
possible_values.sort_unstable();
possible_values
}

pub fn update(&mut self, data: &str) -> Result<()> {
Expand Down Expand Up @@ -541,21 +537,23 @@ impl Config {
Ok(())
}

pub fn list_sessions(&self) -> Result<Vec<String>> {
let sessions_dir = Self::sessions_dir()?;
pub fn list_sessions(&self) -> Vec<String> {
let sessions_dir = match Self::sessions_dir() {
Ok(dir) => dir,
Err(_) => return vec![],
};
match read_dir(&sessions_dir) {
Ok(rd) => {
let mut names = vec![];
for entry in rd {
let entry = entry?;
for entry in rd.flatten() {
let name = entry.file_name();
if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") {
names.push(name.to_string());
}
}
Ok(names)
names
}
Err(_) => Ok(vec![]),
Err(_) => vec![],
}
}

Expand Down Expand Up @@ -665,12 +663,12 @@ impl Config {
let model = match &self.model {
Some(v) => v.clone(),
None => {
let models = all_models(self);
let models = list_models(self);
if models.is_empty() {
bail!("No available model");
}

models[0].full_name()
models[0].id()
}
};
self.set_model(&model)?;
Expand Down
6 changes: 3 additions & 3 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Session {
pub fn new(name: &str, model_info: ModelInfo, role: Option<Role>) -> Self {
let temperature = role.as_ref().and_then(|v| v.temperature);
Self {
model: model_info.full_name(),
model: model_info.id(),
temperature,
messages: vec![],
name: name.to_string(),
Expand Down Expand Up @@ -103,7 +103,7 @@ impl Session {
items.push(("path", path.to_string()));
}

items.push(("model", self.model_info.full_name()));
items.push(("model", self.model_info.id()));

if let Some(temperature) = self.temperature() {
items.push(("temperature", temperature.to_string()));
Expand Down Expand Up @@ -165,7 +165,7 @@ impl Session {
}

pub fn set_model(&mut self, model_info: ModelInfo) -> Result<()> {
self.model = model_info.full_name();
self.model = model_info.id();
self.model_info = model_info;
Ok(())
}
Expand Down
8 changes: 4 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::config::{Config, GlobalConfig};

use anyhow::Result;
use clap::Parser;
use client::{all_models, init_client};
use client::{init_client, list_models};
use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
Expand All @@ -36,13 +36,13 @@ fn main() -> Result<()> {
exit(0);
}
if cli.list_models {
for model in all_models(&config.read()) {
println!("{}", model.full_name());
for model in list_models(&config.read()) {
println!("{}", model.id());
}
exit(0);
}
if cli.list_sessions {
let sessions = config.read().list_sessions()?.join("\n");
let sessions = config.read().list_sessions().join("\n");
println!("{sessions}");
exit(0);
}
Expand Down
94 changes: 94 additions & 0 deletions src/repl/completer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use std::collections::HashMap;

use super::{parse_command, REPL_COMMANDS};

use crate::config::GlobalConfig;

use reedline::{Completer, Span, Suggestion};

impl Completer for ReplCompleter {
fn complete(&mut self, line: &str, pos: usize) -> Vec<Suggestion> {
let mut suggestions = vec![];
if line.len() != pos {
return suggestions;
}
let line = &line[0..pos];
if let Some((cmd, args)) = parse_command(line) {
let commands: Vec<_> = self
.commands
.iter()
.filter(|(cmd_name, _)| match args {
Some(args) => cmd_name.starts_with(&format!("{cmd} {args}")),
None => cmd_name.starts_with(cmd),
})
.collect();

if args.is_some() || line.ends_with(' ') {
let args = args.unwrap_or_default();
let start = line.chars().take_while(|c| *c == ' ').count() + cmd.len() + 1;
let span = Span::new(start, pos);
suggestions.extend(
self.config
.read()
.repl_complete(cmd, args)
.iter()
.map(|name| create_suggestion(name.clone(), None, span)),
)
}

if suggestions.is_empty() {
let start = line.chars().take_while(|c| *c == ' ').count();
let span = Span::new(start, pos);
suggestions.extend(commands.iter().map(|(name, desc)| {
let has_group = self.groups.get(name).map(|v| *v > 1).unwrap_or_default();
let name = if has_group {
name.to_string()
} else {
format!("{name} ")
};
create_suggestion(name, Some(desc.to_string()), span)
}))
}
}
suggestions
}
}

pub struct ReplCompleter {
config: GlobalConfig,
commands: Vec<(&'static str, &'static str)>,
groups: HashMap<&'static str, usize>,
}

impl ReplCompleter {
pub fn new(config: &GlobalConfig) -> Self {
let mut groups = HashMap::new();

let mut commands = REPL_COMMANDS.to_vec();
commands.sort_by(|(a, _), (b, _)| a.cmp(b));

for (name, _) in REPL_COMMANDS.iter() {
if let Some(count) = groups.get(name) {
groups.insert(*name, count + 1);
} else {
groups.insert(*name, 1);
}
}

Self {
config: config.clone(),
commands,
groups,
}
}
}

fn create_suggestion(value: String, description: Option<String>, span: Span) -> Suggestion {
Suggestion {
value,
description,
extra: None,
span,
append_whitespace: false,
}
}
22 changes: 8 additions & 14 deletions src/repl/highlighter.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use super::REPL_COMMANDS;

use crate::config::GlobalConfig;

use nu_ansi_term::{Color, Style};
use reedline::{Highlighter, StyledText};

pub struct ReplHighlighter {
external_commands: Vec<String>,
config: GlobalConfig,
}

impl ReplHighlighter {
pub fn new(external_commands: Vec<String>, config: GlobalConfig) -> Self {
pub fn new(config: &GlobalConfig) -> Self {
Self {
external_commands,
config,
config: config.clone(),
}
}
}
Expand All @@ -28,17 +28,11 @@ impl Highlighter for ReplHighlighter {

let mut styled_text = StyledText::new();

if self
.external_commands
.clone()
.iter()
.any(|x| line.contains(x))
{
let matches: Vec<&str> = self
.external_commands
if REPL_COMMANDS.iter().any(|(cmd, _)| line.contains(cmd)) {
let matches: Vec<&str> = REPL_COMMANDS
.iter()
.filter(|c| line.contains(*c))
.map(std::ops::Deref::deref)
.filter(|(cmd, _)| line.contains(*cmd))
.map(|(cmd, _)| *cmd)
.collect();
let longest_match = matches.iter().fold(String::new(), |acc, &item| {
if item.len() > acc.len() {
Expand Down
Loading

0 comments on commit 652b515

Please sign in to comment.