Skip to content

Commit

Permalink
Add default questions in chat screen.
Browse files Browse the repository at this point in the history
  • Loading branch information
banghuazhao committed Sep 29, 2024
1 parent 7e918cb commit c89a73d
Show file tree
Hide file tree
Showing 13 changed files with 398 additions and 206 deletions.
14 changes: 7 additions & 7 deletions data/lib/data_sources/chat_completion_data_source.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ import 'dart:convert';
import 'package:domain/entities/function_tool.dart';
import 'package:domain/entities/message.dart';
import 'package:http/http.dart' as http;
import '../models/chat_chunk.dart';
import '../utils/api_constants.dart';
import '../utils/network_exceptions.dart';

abstract class ChatCompletionsDataSource {
Stream<String> sendMessages(List<Message> messages, List<FunctionTool> functionTools);
Stream<ChatChunk> sendMessages(List<Message> messages, List<FunctionTool> functionTools);
}

class ChatRemoteDataSourceImpl implements ChatCompletionsDataSource {
Expand All @@ -15,7 +16,7 @@ class ChatRemoteDataSourceImpl implements ChatCompletionsDataSource {
ChatRemoteDataSourceImpl({required this.client});

@override
Stream<String> sendMessages(List<Message> messages,
Stream<ChatChunk> sendMessages(List<Message> messages,
List<FunctionTool> functionTools) async* {
final request = http.Request('POST', Uri.parse(ApiConstants.chatCompletionsEndpoint))
..headers.addAll({
Expand Down Expand Up @@ -58,12 +59,11 @@ class ChatRemoteDataSourceImpl implements ChatCompletionsDataSource {

try {
final data = jsonDecode(jsonString);
final content = data['choices'][0]['delta']['content'] ?? '';
if (content.isNotEmpty) {
yield content;
}
} catch (_) {
final chatChunk = ChatChunk.fromJson(data);
yield chatChunk;
} catch (e) {
// Handle JSON parsing errors
print('Error decoding chat completion JSON: $e. $jsonString');
continue;
}
}
Expand Down
65 changes: 65 additions & 0 deletions data/lib/models/chat_chunk.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import 'package:domain/entities/message.dart';

class ChatChunk {
String? id;
String? object;
int? created;
String? model;
String? systemFingerprint;
List<Choices>? choices;

ChatChunk(
{this.id,
this.object,
this.created,
this.model,
this.systemFingerprint,
this.choices});

ChatChunk.fromJson(Map<String, dynamic> json) {
id = json['id'];
object = json['object'];
created = json['created'];
model = json['model'];
systemFingerprint = json['system_fingerprint'];
if (json['choices'] != null) {
choices = <Choices>[];
json['choices'].forEach((v) {
choices!.add(new Choices.fromJson(v));
});
}
}
}

class Choices {
int? index;
Delta? delta;
String? logprobs;
String? finishReason;

Choices({this.index, this.delta, this.logprobs, this.finishReason});

Choices.fromJson(Map<String, dynamic> json) {
index = json['index'];
delta = json['delta'] != null ? new Delta.fromJson(json['delta']) : null;
logprobs = json['logprobs'];
finishReason = json['finish_reason'];
}
}

class Delta {
String? content;
List<ToolCalls>? toolCalls;

Delta({this.content, this.toolCalls});

Delta.fromJson(Map<String, dynamic> json) {
content = json['content'];
if (json['tool_calls'] != null) {
toolCalls = <ToolCalls>[];
json['tool_calls'].forEach((v) {
toolCalls!.add(ToolCalls.fromJson(v));
});
}
}
}
86 changes: 0 additions & 86 deletions data/lib/models/chat_completion.dart

This file was deleted.

47 changes: 45 additions & 2 deletions data/lib/repositories/chat_repository_impl.dart
Original file line number Diff line number Diff line change
@@ -1,15 +1,58 @@
import 'dart:async';

import 'package:domain/entities/function_tool.dart';
import 'package:domain/entities/message.dart';
import 'package:domain/repositories_abstract/chat_repository.dart';
import '../data_sources/chat_completion_data_source.dart';
import '../models/chat_chunk.dart';

class ChatRepositoryImp implements ChatRepository {
final ChatCompletionsDataSource chatCompletionsDataSource;

ChatRepositoryImp({required this.chatCompletionsDataSource});

@override
Stream<String> sendMessages(List<Message> messages, List<FunctionTool> functionTools) {
return chatCompletionsDataSource.sendMessages(messages, functionTools);
Stream<Message> sendMessages(
List<Message> messages, List<FunctionTool> functionTools) {
// Create a StreamController to accumulate and emit the content as strings
final StreamController<Message> controller = StreamController<Message>();

// Call the original sendMessages method that returns Stream<ChatChunk>
final chatChunks =
chatCompletionsDataSource.sendMessages(messages, functionTools);

Message buffer = Message(role: "assistant");
// Listen to the incoming stream of ChatChunks
chatChunks.listen((ChatChunk chunk) {
// Accumulate or extract the content from the ChatChunk and add it to the stream
final token = chunk.choices?.first.delta?.content;
if (token != null) {
buffer.content = (buffer.content ?? "") + token;
}

final toolCalls = chunk.choices?.first.delta?.toolCalls;
if (toolCalls != null) {
if (buffer.toolCalls == null) {
buffer.toolCalls = toolCalls;
} else {
final newArguments = toolCalls.first.function?.arguments;
if (newArguments != null) {
final currentArg =
buffer.toolCalls?.first.function?.arguments ?? "";
buffer.toolCalls?.first.function?.arguments =
currentArg + newArguments;
}
}
}

controller.add(buffer); // Emit the content to the stream
}, onError: (error) {
controller.addError(error); // Handle errors
}, onDone: () {
controller.close(); // Close the stream when done
});

// Return the stream of accumulated content as strings
return controller.stream;
}
}
80 changes: 68 additions & 12 deletions domain/lib/entities/message.dart
Original file line number Diff line number Diff line change
@@ -1,22 +1,78 @@
class Message {
final String role;
String content;
String? role;
String? content;
List<ToolCalls>? toolCalls;
String? tool_call_id;

Message({required this.role, required this.content});
Message({this.role, this.content, this.toolCalls, this.tool_call_id});

// Factory constructor for creating a new Message instance from JSON
factory Message.fromJson(Map<String, dynamic> json) {
return Message(
role: json['role'],
content: json['content'],
);
Message.fromJson(Map<String, dynamic> json) {
role = json['role'] ?? 'user';
content = json['content'];
if (json['tool_calls'] != null) {
toolCalls = <ToolCalls>[];
json['tool_calls'].forEach((v) {
toolCalls!.add(ToolCalls.fromJson(v));
});
}
tool_call_id = json['tool_call_id'];
}

// Method for converting a Message instance to JSON format
Map<String, dynamic> toJson() {
return {
'role': role,
'content': content,
};
final Map<String, dynamic> data = <String, dynamic>{};
data['role'] = role;
data['content'] = content;
if (toolCalls != null) {
data['tool_calls'] = toolCalls!.map((v) => v.toJson()).toList();
}
data['tool_call_id'] = tool_call_id;
return data;
}
}

class ToolCalls {
String? id;
String? type;
FunctionCall? function;

ToolCalls({this.id, this.type, this.function});

ToolCalls.fromJson(Map<String, dynamic> json) {
id = json['id'];
type = json['type'];
function = json['function'] != null
? FunctionCall.fromJson(json['function'])
: null;
}

Map<String, dynamic> toJson() {
final Map<String, dynamic> data = <String, dynamic>{};
data['id'] = id;
data['type'] = type;
if (function != null) {
data['function'] = function!.toJson();
}
return data;
}
}

class FunctionCall {
String? name;
String? arguments;

FunctionCall({this.name, this.arguments});

FunctionCall.fromJson(Map<String, dynamic> json) {
name = json['name'];
arguments = json['arguments'];
}

Map<String, dynamic> toJson() {
final Map<String, dynamic> data = <String, dynamic>{};
data['name'] = name;
data['arguments'] = arguments;
return data;
}
}
2 changes: 1 addition & 1 deletion domain/lib/repositories_abstract/chat_repository.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ import '../entities/function_tool.dart';
import '../entities/message.dart';

abstract class ChatRepository {
Stream<String> sendMessages(List<Message> messages,
Stream<Message> sendMessages(List<Message> messages,
List<FunctionTool> functionTools);
}
6 changes: 2 additions & 4 deletions domain/lib/usecases/chat_session_usecase.dart
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ class ChatSessionUseCase {
}

// Add a new method to update the last assistant message
void updateLastAssistantMessage(ChatSession session, String token) {
void updateLastAssistantMessage(ChatSession session, Message message) {
// Find the last message that is from the assistant
print(session.messages.last.content);
for (var i = session.messages.length - 1; i >= 0; i--) {
if (session.messages[i].role == 'assistant') {
final previousContent = session.messages[i].content;
session.messages[i] = Message(role: 'assistant', content: previousContent + token);
session.messages[i] = message;
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion domain/lib/usecases/chat_usecase.dart
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ChatUseCase {
content:
"You are an expert assistant specialized in composite materials. Your role is to provide accurate and detailed answers to questions related to composite material properties, design, calculations, and analysis.");

Stream<String> sendMessages(List<Message> messages) {
Stream<Message> sendMessages(List<Message> messages) {
final chatHistory = [systemMessage] + messages;
final functionTools = functionToolsRepository.getAllFunctionTools();
return chatRepository.sendMessages(
Expand Down
Loading

0 comments on commit c89a73d

Please sign in to comment.