Skip to content

Commit

Permalink
Merge pull request #708 from AsakusaRinne/llama3_support
Browse files Browse the repository at this point in the history
Add LLaMA3 chat session example.
  • Loading branch information
AsakusaRinne authored Apr 29, 2024
2 parents 4c078a7 + 175b25d commit 98909dc
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 2 deletions.
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ public class ExampleRunner
{
private static readonly Dictionary<string, Func<Task>> Examples = new()
{
{ "Chat Session: LLama3", LLama3ChatSession.Run },
{ "Chat Session: History", ChatSessionWithHistory.Run },
{ "Chat Session: Role names", ChatSessionWithRoleName.Run },
{ "Chat Session: Role names stripped", ChatSessionStripRoleName.Run },
Expand Down
126 changes: 126 additions & 0 deletions LLama.Examples/Examples/LLama3ChatSession.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
using LLama.Abstractions;
using LLama.Common;

namespace LLama.Examples.Examples;

// When using chatsession, it's a common case that you want to strip the role names
// rather than display them. This example shows how to use transforms to strip them.
public class LLama3ChatSession
{
public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath)
{
Seed = 1337,
GpuLayerCount = 10
};
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

ChatSession session = new(executor, chatHistory);
session.WithHistoryTransform(new LLama3HistoryTransform());
session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
new string[] { "User:", "Assistant:", "" },
redundancyLength: 5));

InferenceParams inferenceParams = new InferenceParams()
{
Temperature = 0.6f,
AntiPrompts = new List<string> { "User:" }
};

Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine("The chat session has started.");

// show the prompt
Console.ForegroundColor = ConsoleColor.Green;
string userInput = Console.ReadLine() ?? "";

while (userInput != "exit")
{
await foreach (
var text
in session.ChatAsync(
new ChatHistory.Message(AuthorRole.User, userInput),
inferenceParams))
{
Console.ForegroundColor = ConsoleColor.White;
Console.Write(text);
}
Console.WriteLine();

Console.ForegroundColor = ConsoleColor.Green;
userInput = Console.ReadLine() ?? "";

Console.ForegroundColor = ConsoleColor.White;
}
}

class LLama3HistoryTransform : IHistoryTransform
{
/// <summary>
/// Convert a ChatHistory instance to plain text.
/// </summary>
/// <param name="history">The ChatHistory instance</param>
/// <returns></returns>
public string HistoryToText(ChatHistory history)
{
string res = Bos;
foreach (var message in history.Messages)
{
res += EncodeMessage(message);
}
res += EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, ""));
return res;
}

private string EncodeHeader(ChatHistory.Message message)
{
string res = StartHeaderId;
res += message.AuthorRole.ToString();
res += EndHeaderId;
res += "\n\n";
return res;
}

private string EncodeMessage(ChatHistory.Message message)
{
string res = EncodeHeader(message);
res += message.Content;
res += EndofTurn;
return res;
}

/// <summary>
/// Converts plain text to a ChatHistory instance.
/// </summary>
/// <param name="role">The role for the author.</param>
/// <param name="text">The chat history as plain text.</param>
/// <returns>The updated history.</returns>
public ChatHistory TextToHistory(AuthorRole role, string text)
{
return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
}

/// <summary>
/// Copy the transform.
/// </summary>
/// <returns></returns>
public IHistoryTransform Clone()
{
return new LLama3HistoryTransform();
}

private const string StartHeaderId = "<|start_header_id|>";
private const string EndHeaderId = "<|end_header_id|>";
private const string Bos = "<|begin_of_text|>";
private const string Eos = "<|end_of_text|>";
private const string EndofTurn = "<|eot_id|>";
}
}
8 changes: 6 additions & 2 deletions LLama/LLamaTransforms.cs
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,19 @@ public async IAsyncEnumerable<string> TransformAsync(IAsyncEnumerable<string> to
var current = string.Join("", window);
if (_keywords.Any(x => current.Contains(x)))
{
var matchedKeyword = _keywords.First(x => current.Contains(x));
var matchedKeywords = _keywords.Where(x => current.Contains(x));
int total = window.Count;
for (int i = 0; i < total; i++)
{
window.Dequeue();
}
if (!_removeAllMatchedTokens)
{
yield return current.Replace(matchedKeyword, "");
foreach(var keyword in matchedKeywords)
{
current = current.Replace(keyword, "");
}
yield return current;
}
}
if (current.Length >= _maxKeywordLength)
Expand Down

0 comments on commit 98909dc

Please sign in to comment.