Skip to content

Commit

Permalink
Adds chatbot example for llama.cpp engine (#413)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanktliu authored Jan 15, 2024
1 parent 2d8ea6f commit e40c28b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ systemProp.org.gradle.internal.http.connectionTimeout=60000
# FIXME: Workaround gradle publish issue: https://github.com/gradle/gradle/issues/11308
systemProp.org.gradle.internal.publish.checksums.insecure=true

djl_version=0.25.0
djl_version=0.26.0
commons_cli_version=1.5.0
log4j_slf4j_version=2.19.0
slf4j_simple_version=1.7.36
Expand Down
4 changes: 3 additions & 1 deletion huggingface/nlp/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,18 @@ dependencies {
implementation platform("ai.djl:bom:${djl_version}")
implementation "ai.djl:api"
implementation "ai.djl.huggingface:tokenizers"
implementation "ai.djl.llama:llama"

runtimeOnly "ai.djl.pytorch:pytorch-engine"
implementation "org.slf4j:slf4j-simple:${slf4j_simple_version}"
}

application {
mainClass = "com.examples.QuestionAnswering"
mainClass = System.getProperty("main", "com.examples.QuestionAnswering")
}

run {
standardInput = System.in
systemProperties System.getProperties()
systemProperties.remove("user.dir")
systemProperty("file.encoding", "UTF-8")
Expand Down
102 changes: 102 additions & 0 deletions huggingface/nlp/src/main/java/com/examples/Chatbot.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.examples;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.llama.engine.LlamaInput;
import ai.djl.llama.engine.LlamaTranslatorFactory;
import ai.djl.llama.jni.Token;
import ai.djl.llama.jni.TokenIterator;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Set;

public class Chatbot {

public static void main(String[] args) throws ModelException, IOException, TranslateException {
String modelId;
String quantMethod;
if (args.length > 0) {
modelId = args[0];
if (args.length > 1) {
quantMethod = args[1];
} else {
quantMethod = "Q4_K_M";
}
} else {
// modelId = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF";
// quantMethod = "Q4_K_M";
modelId = "TinyLlama/TinyLlama-1.1B-Chat-v0.6";
quantMethod = "Q4_0";
}
System.out.println("Using model: " + modelId);

String url = "djl://ai.djl.huggingface.gguf/" + modelId + "/0.0.1/" + quantMethod;
Criteria<LlamaInput, TokenIterator> criteria =
Criteria.builder()
.setTypes(LlamaInput.class, TokenIterator.class)
.optModelUrls(url)
.optEngine("Llama")
.optOption("number_gpu_layers", "43")
.optTranslatorFactory(new LlamaTranslatorFactory())
.optProgress(new ProgressBar())
.build();

String system =
"This is demo for DJL Llama.cpp engine.\n\n"
+ "Llama: Hello. How may I help you today?";

LlamaInput.Parameters param = new LlamaInput.Parameters();
param.setTemperature(0.7f);
param.setPenalizeNl(true);
param.setMirostat(2);
param.setAntiPrompt(new String[] {"User: "});

LlamaInput in = new LlamaInput();
in.setParameters(param);

BufferedReader reader =
new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8));
try (ZooModel<LlamaInput, TokenIterator> model = criteria.loadModel();
Predictor<LlamaInput, TokenIterator> predictor = model.newPredictor()) {
System.out.print(system);
StringBuilder prompt = new StringBuilder(system);
Set<String> exitWords = Set.of("exit", "bye", "quit");
while (true) {
System.out.print("\nUser: ");
String input = reader.readLine().trim();
if (exitWords.contains(input.toLowerCase(Locale.ROOT))) {
break;
}
System.out.print("Llama: ");
prompt.append("\nUser: ").append(input).append("\nLlama: ");
in.setInputs(prompt.toString());
TokenIterator it = predictor.predict(in);
while (it.hasNext()) {
Token token = it.next();
System.out.print(token.getText());
prompt.append(token.getText());
}
}
}
}
}

0 comments on commit e40c28b

Please sign in to comment.