Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve repl completer #199

Merged
merged 1 commit into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/client/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ macro_rules! register_client {
anyhow::bail!("Unknown client {}", client)
}

pub fn all_models(config: &$crate::config::Config) -> Vec<$crate::client::ModelInfo> {
pub fn list_models(config: &$crate::config::Config) -> Vec<$crate::client::ModelInfo> {
config
.clients
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/client/model_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl ModelInfo {
self
}

pub fn full_name(&self) -> String {
pub fn id(&self) -> String {
format!("{}:{}", self.client, self.name)
}

Expand Down
72 changes: 35 additions & 37 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use self::role::Role;
use self::session::{Session, TEMP_SESSION_NAME};

use crate::client::{
all_models, create_client_config, list_client_types, ClientConfig, ExtraConfig, Message,
create_client_config, list_client_types, list_models, ClientConfig, ExtraConfig, Message,
ModelInfo, OpenAIClient, SendData,
};
use crate::render::{MarkdownRender, RenderOptions};
Expand Down Expand Up @@ -35,16 +35,6 @@ const ROLES_FILE_NAME: &str = "roles.yaml";
const MESSAGES_FILE_NAME: &str = "messages.md";
const SESSIONS_DIR_NAME: &str = "sessions";

const SET_COMPLETIONS: [&str; 7] = [
".set temperature",
".set save true",
".set save false",
".set highlight true",
".set highlight false",
".set dry_run true",
".set dry_run false",
];

const CLIENTS_FIELD: &str = "clients";

#[derive(Debug, Clone, Deserialize)]
Expand Down Expand Up @@ -311,10 +301,11 @@ impl Config {
}

pub fn set_model(&mut self, value: &str) -> Result<()> {
let models = all_models(self);
let models = list_models(self);
let mut model_info = None;
let value = value.trim_end_matches(':');
if value.contains(':') {
if let Some(model) = models.iter().find(|v| v.full_name() == value) {
if let Some(model) = models.iter().find(|v| v.id() == value) {
model_info = Some(model.clone());
}
} else if let Some(model) = models.iter().find(|v| v.client == value) {
Expand Down Expand Up @@ -345,7 +336,7 @@ impl Config {
.clone()
.map_or_else(|| String::from("no"), |v| v.to_string());
let items = vec![
("model", self.model_info.full_name()),
("model", self.model_info.id()),
("temperature", temperature),
("dry_run", self.dry_run.to_string()),
("save", self.save.to_string()),
Expand Down Expand Up @@ -402,22 +393,27 @@ impl Config {
.unwrap_or_default()
}

pub fn repl_completions(&self) -> Vec<String> {
let mut completion: Vec<String> = self
.roles
.iter()
.map(|v| format!(".role {}", v.name))
pub fn repl_complete(&self, cmd: &str, args: &str) -> Vec<String> {
let possible_values = match cmd {
".role" => self.roles.iter().map(|v| v.name.clone()).collect(),
".model" => list_models(self).into_iter().map(|v| v.id()).collect(),
".session" => self.list_sessions(),
".set" => {
vec![
"temperature ".into(),
format!("save {}", !self.save),
format!("highlight {}", !self.highlight),
format!("dry_run {}", !self.dry_run),
]
}
_ => vec![],
};
let mut possible_values: Vec<String> = possible_values
.into_iter()
.filter(|v| v.starts_with(args))
.collect();

completion.extend(SET_COMPLETIONS.map(std::string::ToString::to_string));
completion.extend(
all_models(self)
.iter()
.map(|v| format!(".model {}", v.full_name())),
);
let sessions = self.list_sessions().unwrap_or_default();
completion.extend(sessions.iter().map(|v| format!(".session {}", v)));
completion
possible_values.sort_unstable();
possible_values
}

pub fn update(&mut self, data: &str) -> Result<()> {
Expand Down Expand Up @@ -541,21 +537,23 @@ impl Config {
Ok(())
}

pub fn list_sessions(&self) -> Result<Vec<String>> {
let sessions_dir = Self::sessions_dir()?;
pub fn list_sessions(&self) -> Vec<String> {
let sessions_dir = match Self::sessions_dir() {
Ok(dir) => dir,
Err(_) => return vec![],
};
match read_dir(&sessions_dir) {
Ok(rd) => {
let mut names = vec![];
for entry in rd {
let entry = entry?;
for entry in rd.flatten() {
let name = entry.file_name();
if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") {
names.push(name.to_string());
}
}
Ok(names)
names
}
Err(_) => Ok(vec![]),
Err(_) => vec![],
}
}

Expand Down Expand Up @@ -665,12 +663,12 @@ impl Config {
let model = match &self.model {
Some(v) => v.clone(),
None => {
let models = all_models(self);
let models = list_models(self);
if models.is_empty() {
bail!("No available model");
}

models[0].full_name()
models[0].id()
}
};
self.set_model(&model)?;
Expand Down
6 changes: 3 additions & 3 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Session {
pub fn new(name: &str, model_info: ModelInfo, role: Option<Role>) -> Self {
let temperature = role.as_ref().and_then(|v| v.temperature);
Self {
model: model_info.full_name(),
model: model_info.id(),
temperature,
messages: vec![],
name: name.to_string(),
Expand Down Expand Up @@ -103,7 +103,7 @@ impl Session {
items.push(("path", path.to_string()));
}

items.push(("model", self.model_info.full_name()));
items.push(("model", self.model_info.id()));

if let Some(temperature) = self.temperature() {
items.push(("temperature", temperature.to_string()));
Expand Down Expand Up @@ -165,7 +165,7 @@ impl Session {
}

pub fn set_model(&mut self, model_info: ModelInfo) -> Result<()> {
self.model = model_info.full_name();
self.model = model_info.id();
self.model_info = model_info;
Ok(())
}
Expand Down
8 changes: 4 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::config::{Config, GlobalConfig};

use anyhow::Result;
use clap::Parser;
use client::{all_models, init_client};
use client::{init_client, list_models};
use crossbeam::sync::WaitGroup;
use is_terminal::IsTerminal;
use parking_lot::RwLock;
Expand All @@ -36,13 +36,13 @@ fn main() -> Result<()> {
exit(0);
}
if cli.list_models {
for model in all_models(&config.read()) {
println!("{}", model.full_name());
for model in list_models(&config.read()) {
println!("{}", model.id());
}
exit(0);
}
if cli.list_sessions {
let sessions = config.read().list_sessions()?.join("\n");
let sessions = config.read().list_sessions().join("\n");
println!("{sessions}");
exit(0);
}
Expand Down
94 changes: 94 additions & 0 deletions src/repl/completer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use std::collections::HashMap;

use super::{parse_command, REPL_COMMANDS};

use crate::config::GlobalConfig;

use reedline::{Completer, Span, Suggestion};

impl Completer for ReplCompleter {
fn complete(&mut self, line: &str, pos: usize) -> Vec<Suggestion> {
let mut suggestions = vec![];
if line.len() != pos {
return suggestions;
}
let line = &line[0..pos];
if let Some((cmd, args)) = parse_command(line) {
let commands: Vec<_> = self
.commands
.iter()
.filter(|(cmd_name, _)| match args {
Some(args) => cmd_name.starts_with(&format!("{cmd} {args}")),
None => cmd_name.starts_with(cmd),
})
.collect();

if args.is_some() || line.ends_with(' ') {
let args = args.unwrap_or_default();
let start = line.chars().take_while(|c| *c == ' ').count() + cmd.len() + 1;
let span = Span::new(start, pos);
suggestions.extend(
self.config
.read()
.repl_complete(cmd, args)
.iter()
.map(|name| create_suggestion(name.clone(), None, span)),
)
}

if suggestions.is_empty() {
let start = line.chars().take_while(|c| *c == ' ').count();
let span = Span::new(start, pos);
suggestions.extend(commands.iter().map(|(name, desc)| {
let has_group = self.groups.get(name).map(|v| *v > 1).unwrap_or_default();
let name = if has_group {
name.to_string()
} else {
format!("{name} ")
};
create_suggestion(name, Some(desc.to_string()), span)
}))
}
}
suggestions
}
}

pub struct ReplCompleter {
config: GlobalConfig,
commands: Vec<(&'static str, &'static str)>,
groups: HashMap<&'static str, usize>,
}

impl ReplCompleter {
pub fn new(config: &GlobalConfig) -> Self {
let mut groups = HashMap::new();

let mut commands = REPL_COMMANDS.to_vec();
commands.sort_by(|(a, _), (b, _)| a.cmp(b));

for (name, _) in REPL_COMMANDS.iter() {
if let Some(count) = groups.get(name) {
groups.insert(*name, count + 1);
} else {
groups.insert(*name, 1);
}
}

Self {
config: config.clone(),
commands,
groups,
}
}
}

fn create_suggestion(value: String, description: Option<String>, span: Span) -> Suggestion {
Suggestion {
value,
description,
extra: None,
span,
append_whitespace: false,
}
}
22 changes: 8 additions & 14 deletions src/repl/highlighter.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
use super::REPL_COMMANDS;

use crate::config::GlobalConfig;

use nu_ansi_term::{Color, Style};
use reedline::{Highlighter, StyledText};

pub struct ReplHighlighter {
external_commands: Vec<String>,
config: GlobalConfig,
}

impl ReplHighlighter {
pub fn new(external_commands: Vec<String>, config: GlobalConfig) -> Self {
pub fn new(config: &GlobalConfig) -> Self {
Self {
external_commands,
config,
config: config.clone(),
}
}
}
Expand All @@ -28,17 +28,11 @@ impl Highlighter for ReplHighlighter {

let mut styled_text = StyledText::new();

if self
.external_commands
.clone()
.iter()
.any(|x| line.contains(x))
{
let matches: Vec<&str> = self
.external_commands
if REPL_COMMANDS.iter().any(|(cmd, _)| line.contains(cmd)) {
let matches: Vec<&str> = REPL_COMMANDS
.iter()
.filter(|c| line.contains(*c))
.map(std::ops::Deref::deref)
.filter(|(cmd, _)| line.contains(*cmd))
.map(|(cmd, _)| *cmd)
.collect();
let longest_match = matches.iter().fold(String::new(), |acc, &item| {
if item.len() > acc.len() {
Expand Down
Loading