Skip to content

Commit

Permalink
feat(shared_instance): java crac
Browse files Browse the repository at this point in the history
  • Loading branch information
ActivePeter committed Jun 13, 2024
1 parent d563329 commit 5a804c2
Show file tree
Hide file tree
Showing 25 changed files with 1,221 additions and 362 deletions.
4 changes: 4 additions & 0 deletions demos/_java_serverless_lib/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</dependency>
<dependency>
<groupId>org.crac</groupId>
<artifactId>crac</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,9 @@ public UdsBackend udsBackend() {
public RpcHandleOwner rpcHandleOwner() {
return new RpcHandleOwner();
}

@Bean
public CracManager cracManager() {
return new CracManager();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.springframework.boot.DefaultApplicationArguments;
import org.springframework.core.env.Environment;

@Component
public class BootArgCheck implements CommandLineRunner
// DisposableBean
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package io.serverless_lib;

import org.springframework.stereotype.Component;
import org.springframework.boot.CommandLineRunner;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.event.EventListener;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.DefaultApplicationArguments;
import org.springframework.core.env.Environment;
import process_rpc_proto.ProcessRpcProto.UpdateCheckpoint;

public class CracManager implements org.crac.Resource
// DisposableBean
{
boolean checkpointed = false;

@Autowired
UdsBackend uds;

@EventListener
public void bootArgCheckOk(BootArgCheckOkEvent e) {
// /-----------------------\
// /-------------------------------------- | # Agent |
// | \-----------------------/
// /----------------------\
// | # CracManager |
// | // first time init |
// | checkpointed = false |
// \----------------------/
// |
// | bootArgCheckOk
// | if (!checkpointed) ----- uds call -----> |
// |
// | <----------------------------------------/ take snapshot
// |
// | beforeCheckpoint
// \ checkpointed = true--------------------> |
// |
// /----------------------\ <-------------------------/ restart by crac
// | # CracManager |
// | // criu init |
// | checkpointed = true |
// \----------------------/

// register crac
org.crac.Core.getGlobalContext().register(this);

if (!checkpointed) {
System.out.println("CracManager requires for snapshot.");
uds.send(new UdsPack(UpdateCheckpoint.newBuilder().build(),0));
// TODO: change to rpc
checkpointed = true;
}
}

@Override
public void beforeCheckpoint(org.crac.Context<? extends org.crac.Resource> context) throws Exception {
// Prepare for checkpoint
System.out.println("CracManager before checkpoint.");
// close the uds
uds.close();
}

@Override
public void afterRestore(org.crac.Context<? extends org.crac.Resource> context) throws Exception {
// Handle after restore
System.out.println("CracManager after restore.");
uds.start();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void createHandleMeta(T service) {

// protected abstract Object rpcDispatch(String func, Object requestObj);

public ByteBuf handleRpc(String func, ByteBuf req) {
public String handleRpc(String func, String argStr) {
// Get meta
RpcFuncMeta rpcFunc = handleMeta.get(func);
if (rpcFunc == null) {
Expand All @@ -82,10 +82,7 @@ public ByteBuf handleRpc(String func, ByteBuf req) {
// Deserialize request
Object[] params;
try {
byte[] reqBytes = new byte[req.readableBytes()];
req.readBytes(reqBytes);
String reqJson = new String(reqBytes);
JsonObject jsonObject = gson.fromJson(reqJson, JsonObject.class);
JsonObject jsonObject = gson.fromJson(argStr, JsonObject.class);

// Extract parameters from JSON and convert to appropriate types
Map<String, Class<?>> paramMeta = rpcFunc.getParameters();
Expand All @@ -109,16 +106,16 @@ public ByteBuf handleRpc(String func, ByteBuf req) {
throw new RuntimeException("Failed to invoke method", e);
}

// Serialize response
ByteBuf responseBuf;
try {
String respJson = gson.toJson(responseObj);
byte[] respBytes = respJson.getBytes();
responseBuf = Unpooled.wrappedBuffer(respBytes);
} catch (Exception e) {
throw new RuntimeException("Failed to serialize response", e);
}

return responseBuf;
// // Serialize response
// ByteBuf responseBuf;
// try {
// String respJson = gson.toJson(responseObj);
// byte[] respBytes = respJson.getBytes();
// responseBuf = Unpooled.wrappedBuffer(respBytes);
// } catch (Exception e) {
// throw new RuntimeException("Failed to serialize response", e);
// }

return gson.toJson(responseObj);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

import javax.annotation.PostConstruct;
import process_rpc_proto.ProcessRpcProto.FuncStarted;
import process_rpc_proto.ProcessRpcProto.AppStarted;
import process_rpc_proto.ProcessRpcProto.FuncCallReq;
import process_rpc_proto.ProcessRpcProto.FuncCallResp;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -16,6 +19,9 @@
import io.netty.channel.epoll.EpollDomainSocketChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.unix.DomainSocketAddress;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LengthFieldPrepender;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.protobuf.ProtobufDecoder;
Expand All @@ -32,6 +38,8 @@
import org.springframework.context.event.EventListener;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.DefaultApplicationArguments;
import java.util.ArrayList;
import java.util.concurrent.locks.ReentrantLock;

public class UdsBackend
// DisposableBean
Expand All @@ -42,13 +50,72 @@ public class UdsBackend
@Autowired
RpcHandleOwner rpcHandleOwner;

Channel channel = null;

String agentSock="";

String httpPort="";

private final ReentrantLock sendlock = new ReentrantLock();
List<UdsPack> waitingPacks=new ArrayList<>();

@EventListener
public void bootArgCheckOk(BootArgCheckOkEvent e) {
this.agentSock = e.agentSock;
this.httpPort = e.httpPort;
start();
}

public void start(){
netty_thread = new Thread(() -> {
UnixChannelHandle.start(Paths.get(e.agentSock), e.httpPort, rpcHandleOwner);
UnixChannelHandle.start(Paths.get(agentSock), httpPort, rpcHandleOwner, this);
});
netty_thread.start();
}


public void send(UdsPack pack){
sendlock.lock();
if(channel==null){
// chennel 读到null后,还没连接{加入队列} else {也可能连接了,if {还没消费掉队列,} else {队列已经消费!!!泄露}}
// 因此需要锁,保证channel 为null时,消息一定加到队列
System.out.println("Channel is not ready, packs will be sent later.");
waitingPacks.add(pack);

sendlock.unlock();
return;
}
sendlock.unlock();

System.out.println("Sending pack, packid:"+pack.id+", taskid:"+pack.taskId);
channel.writeAndFlush(pack.encode());
}

public void setUpChannel(Channel channel){
sendlock.lock();
this.channel=channel;
for(UdsPack pack:waitingPacks){
System.out.println("Sending pended pack, packid:"+pack.id+", taskid:"+pack.taskId);
send(pack);
}
waitingPacks.clear();
sendlock.unlock();
}

public void close(){
try{

sendlock.lock();
channel.close().sync();
netty_thread.join();
channel=null;
sendlock.unlock();
}catch (Exception e){
System.out.println("close uds with err");
e.printStackTrace();
sendlock.unlock();
}
}
}

class ByteBufInputStream extends InputStream {
Expand Down Expand Up @@ -83,6 +150,16 @@ public int available() throws IOException {
}
}

class RpcPack {
public int taskId;
public ByteBuf packData;

public RpcPack(int taskId, ByteBuf packData) {
this.taskId = taskId;
this.packData = packData;
}
}

class UnixChannelHandle {
static void waitingForSockFile(Path sock_path) {
System.out.println("Current directory: " + Paths.get(".").toAbsolutePath().toString());
Expand All @@ -105,7 +182,7 @@ static void waitingForSockFile(Path sock_path) {
}
}

static void start(Path sock_path, String httpPort, RpcHandleOwner rpcHandleOwner) {
static void start(Path sock_path, String httpPort, RpcHandleOwner rpcHandleOwner, UdsBackend udsHandle) {
io.netty.bootstrap.Bootstrap bootstrap = new io.netty.bootstrap.Bootstrap();
final EpollEventLoopGroup epollEventLoopGroup = new EpollEventLoopGroup();
try {
Expand All @@ -115,27 +192,74 @@ static void start(Path sock_path, String httpPort, RpcHandleOwner rpcHandleOwner
@Override
public void initChannel(UnixChannel ch) throws Exception {
ch.pipeline()
.addLast(new SimpleChannelInboundHandler<ByteBuf>() {
.addLast(new ByteToMessageDecoder() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg)
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out)
throws Exception {
// 确保有足够的字节来读取长度字段
if (in.readableBytes() < 4) {
return;
}

// 标记当前的读索引
in.markReaderIndex();

// 读取长度字段
int length = in.readInt();
int taskId = in.readInt();

// 确保有足够的字节来读取数据
if (in.readableBytes() < length) {
// 重置读索引
in.resetReaderIndex();
return;
}

// 读取数据
ByteBuf frame = in.readBytes(length);
out.add(new RpcPack(taskId, frame));
}
})
.addLast(new SimpleChannelInboundHandler<RpcPack>() {
@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcPack msg)
throws Exception {
System.out.println(
"Received message from server: " + msg.packData.readableBytes());
// read four bytre for id
ByteBufInputStream stream = new ByteBufInputStream(msg.packData);

FuncCallReq funcCallReq = FuncCallReq
.parseFrom(new ByteBufInputStream(msg));
.parseFrom(stream);

// Handle the deserialized message
String func = funcCallReq.getFunc();
String argStr = funcCallReq.getArgStr();

// 需要一个线程池来处理消息
rpcHandleOwner.rpcHandle.handleRpc(func, msg);
try {
String resStr = rpcHandleOwner.rpcHandle.handleRpc(func, argStr);
FuncCallResp resp = FuncCallResp.newBuilder().setRetStr(resStr)
.build();

// byte[] data = resp.toByteArray();
// ByteBuf buffer = Unpooled.buffer(8 + data.length);
// buffer.writeInt(data.length);
// buffer.writeInt(msg.taskId);
// buffer.writeBytes(data);
ctx.writeAndFlush(new UdsPack(resp,msg.taskId).encode());
System.out.println("Response sent.");
} catch (Exception e) {
e.printStackTrace();
}
}

@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
System.out.println("Channel is active");

// Create AuthHeader message
FuncStarted commu = FuncStarted.newBuilder().setFnid("stock-mng")
AppStarted commu = AppStarted.newBuilder().setAppid("stock-mng")
.setHttpPort(httpPort).build();

// Serialize the message
Expand Down Expand Up @@ -163,9 +287,9 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
});
waitingForSockFile(sock_path);
// System.out.println("agent's sock is ready");
Channel channel = bootstrap.connect(new DomainSocketAddress(sock_path.toAbsolutePath().toString())).sync()
Channel channel = bootstrap.connect(new DomainSocketAddress(sock_path.toString())).sync()
.channel();

udsHandle.setUpChannel(channel);
channel.closeFuture().sync();

// final FullHttpRequest request = new
Expand Down
Loading

0 comments on commit 5a804c2

Please sign in to comment.