代码沙箱优化：模板方法模式
上次我们完成了代码沙箱的 Docker 实现和 Java 原生实现。
现在可以使用模板方法设计模式来优化这部分代码。
Java 原生实现的代码沙箱和 Docker 实现的代码沙箱之间有很多重复的地方，
比如保存代码文件、收集输出结果、进行校验等。
因此我们要用模板方法设计模式：
在父类中定义一套通用的执行流程，让子类去负责每个具体步骤的实现。
模板方法的应用场景：适用于有规范的流程，并且执行流程可以复用。
作用是：大幅节约重复代码量，便于项目扩展，也更易于维护。
具体做法是先抽象出执行流程：
先复制一份具体的实现类，
再把代码从一个完整的方法抽离成一个个子方法。
第一步：把用户代码保存为文件
/**
 * 1. Persist the user's source code to disk.
 *
 * <p>Every call writes into a fresh random sub-directory of the global code directory, which
 * keeps each submission's files isolated from the others.
 *
 * @param code the user's source code
 * @return the written {@code GLOBAL_JAVA_CLASS_NAME} file
 */
public File saveCodeToFile(String code) {
    String globalCodePathName = System.getProperty("user.dir") + File.separator + GLOBAL_CODE_DIR_NAME;
    // Lazily create the shared root directory on first use
    if (!FileUtil.exist(globalCodePathName)) {
        FileUtil.mkdir(globalCodePathName);
    }
    // One random directory per submission
    String isolatedDir = globalCodePathName + File.separator + UUID.randomUUID();
    String targetPath = isolatedDir + File.separator + GLOBAL_JAVA_CLASS_NAME;
    return FileUtil.writeString(code, targetPath, StandardCharsets.UTF_8);
}
第二步：编译代码，获得 class 文件
/**
 * 2. Compile the saved source file with {@code javac}.
 *
 * @param userCodeFile the saved source file
 * @return the compiler's execution message (exit value, output, error output, time)
 * @throws RuntimeException if {@code javac} cannot be started, or "编译错误" if it exits non-zero
 */
public ExecuteMessage compileFile(File userCodeFile) {
    ExecuteMessage executeMessage;
    try {
        // Pass the arguments as a list: Runtime.exec(String) splits on whitespace, which broke
        // the command whenever the working directory contained a space.
        Process compileProcess = new ProcessBuilder(
                "javac", "-encoding", "utf-8", userCodeFile.getAbsolutePath()).start();
        executeMessage = ProcessUtils.runProcessAndGetMessage(compileProcess, "编译");
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
    // Checked outside the try so the "编译错误" exception is not re-wrapped by the catch above
    if (executeMessage.getExitValue() != 0) {
        throw new RuntimeException("编译错误");
    }
    return executeMessage;
}
第三步：执行文件，获得执行结果列表
/**
 * 3. Run the compiled {@code Main} class once per input and collect each run's result.
 *
 * <p>Each run is killed if it is still alive after {@code TIME_OUT} milliseconds.
 *
 * @param userCodeFile the saved source file; its parent directory is used as the class path
 * @param inputList    one entry per run; each entry is split on whitespace into program
 *                     arguments ({@code null} list is treated as empty)
 * @return one execution message per input, in input order
 * @throws RuntimeException ("执行错误") if a run cannot be started or its result collected
 */
public List<ExecuteMessage> runFile(File userCodeFile, List<String> inputList) {
    String userCodeParentPath = userCodeFile.getParentFile().getAbsolutePath();
    List<ExecuteMessage> executeMessageList = new ArrayList<>();
    if (inputList == null) {
        return executeMessageList;
    }
    for (String inputArgs : inputList) {
        // Build the command as separate arguments so the class path stays one argument even if
        // it contains spaces; the input is still split with StringTokenizer's default
        // delimiters, which is exactly how Runtime.exec(String) tokenized it before.
        List<String> runCmd = new ArrayList<>();
        runCmd.add("java");
        runCmd.add("-Xmx256m");
        runCmd.add("-Dfile.encoding=UTF-8");
        runCmd.add("-cp");
        runCmd.add(userCodeParentPath);
        runCmd.add("Main");
        if (inputArgs != null) {
            java.util.StringTokenizer tokenizer = new java.util.StringTokenizer(inputArgs);
            while (tokenizer.hasMoreTokens()) {
                runCmd.add(tokenizer.nextToken());
            }
        }
        try {
            Process runProcess = new ProcessBuilder(runCmd).start();
            // Timeout watchdog; daemon so a pending watchdog never keeps the JVM alive
            Thread watchdog = new Thread(() -> {
                try {
                    Thread.sleep(TIME_OUT);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                // Only report and kill a program that is really still running. destroyForcibly
                // cannot be delayed by shutdown hooks registered by the user's code.
                if (runProcess.isAlive()) {
                    System.out.println("超时了,中断");
                    runProcess.destroyForcibly();
                }
            });
            watchdog.setDaemon(true);
            watchdog.start();
            ExecuteMessage executeMessage = ProcessUtils.runProcessAndGetMessage(runProcess, "运行");
            System.out.println(executeMessage);
            executeMessageList.add(executeMessage);
        } catch (Exception e) {
            throw new RuntimeException("执行错误", e);
        }
    }
    return executeMessageList;
}
第四步：获取输出结果
/**
 * 4. Aggregate the per-run execution messages into a single response.
 *
 * <p>Stops at the first run with non-blank error output and marks the response with status 3
 * (user code error); if no run failed, status is 1 (normal completion). The reported time is
 * the largest per-run time seen.
 *
 * @param executeMessageList results of the individual runs, in input order
 * @return the aggregated response
 */
public ExecuteCodeResponse getOutputResponse(List<ExecuteMessage> executeMessageList) {
    ExecuteCodeResponse response = new ExecuteCodeResponse();
    List<String> collectedOutputs = new ArrayList<>();
    long slowestRun = 0;
    boolean hitError = false;
    for (ExecuteMessage run : executeMessageList) {
        String stderr = run.getErrorMessage();
        if (StrUtil.isNotBlank(stderr)) {
            // The user's code reported an error while running
            response.setMessage(stderr);
            response.setStatus(3);
            hitError = true;
            break;
        }
        collectedOutputs.add(run.getMessage());
        Long elapsed = run.getTime();
        if (elapsed != null) {
            slowestRun = Math.max(slowestRun, elapsed);
        }
    }
    if (!hitError) {
        // Every run finished normally
        response.setStatus(1);
    }
    response.setOutputList(collectedOutputs);
    JudgeInfo judgeInfo = new JudgeInfo();
    judgeInfo.setTime(slowestRun);
    // Memory usage is not measured here (it would need a third-party library)
    response.setJudgeInfo(judgeInfo);
    return response;
}
第五步：删除文件
/**
 * 5. Delete the submission's working directory (the parent of the saved source file).
 *
 * @param userCodeFile the saved source file
 * @return {@code true} if the directory was deleted or there is no parent directory,
 *         {@code false} if deletion failed
 */
public boolean deleteFile(File userCodeFile) {
    File parentDir = userCodeFile.getParentFile();
    if (parentDir == null) {
        return true;
    }
    boolean deleted = FileUtil.del(parentDir.getAbsolutePath());
    System.out.println("删除" + (deleted ? "成功" : "失败"));
    return deleted;
}
第六步：获取错误响应
/**
 * 6. Build the response returned when the sandbox itself fails.
 *
 * @param e the failure; its message becomes the response message
 * @return a response with no outputs, status 2 and an empty {@link JudgeInfo}
 */
private ExecuteCodeResponse getErrorResponse(Throwable e) {
    ExecuteCodeResponse errorResponse = new ExecuteCodeResponse();
    errorResponse.setOutputList(new ArrayList<>());
    errorResponse.setMessage(e.getMessage());
    // Status 2 marks a sandbox-side error
    errorResponse.setStatus(2);
    errorResponse.setJudgeInfo(new JudgeInfo());
    return errorResponse;
}
因为我们之前写的程序本身就是流程化的，
所以把它拆成一个个方法会很简单，
自上而下一层一层地抽取即可。
最后重写 CodeSandbox 接口中的 executeCode 方法，按顺序串起这些步骤：
/**
 * Template method: run the whole sandbox pipeline for one request.
 *
 * <p>Steps: save the code, compile it, run it once per input, aggregate the results. The
 * submission's working directory is deleted afterwards even if compiling or running throws.
 *
 * @param executeCodeRequest the code and the list of inputs
 * @return the aggregated execution response
 * @throws RuntimeException if compilation fails or a run cannot be started/completed
 */
@Override
public ExecuteCodeResponse executeCode(ExecuteCodeRequest executeCodeRequest) {
    List<String> inputList = executeCodeRequest.getInputList();
    String code = executeCodeRequest.getCode();
    // 1. Save the user's code to a file
    File userCodeFile = saveCodeToFile(code);
    try {
        // 2. Compile the code into a class file
        ExecuteMessage compileFileExecuteMessage = compileFile(userCodeFile);
        System.out.println(compileFileExecuteMessage);
        // 3. Run the code and collect the per-input results
        List<ExecuteMessage> executeMessageList = runFile(userCodeFile, inputList);
        // 4. Aggregate the results into the response
        return getOutputResponse(executeMessageList);
    } finally {
        // 5. Clean up in finally so failed compiles/runs don't leak the submission directory
        boolean b = deleteFile(userCodeFile);
        if (!b) {
            log.error("deleteFile error, userCodeFilePath = {}", userCodeFile.getAbsolutePath());
        }
    }
}
现在我们就能定义子类了，
子类直接复用父类的模板方法即可。
这就是模板方法设计模式（Template Method）。
package com.yupi.yuojcodesandbox;
import com.yupi.yuojcodesandbox.model.ExecuteCodeRequest;
import com.yupi.yuojcodesandbox.model.ExecuteCodeResponse;
import org.springframework.stereotype.Component;
/**
 * Native (non-Docker) Java code sandbox.
 *
 * <p>It overrides no step, so every step uses the default implementation inherited from
 * {@link JavaCodeSandboxTemplate}, including the template method {@code executeCode} itself.
 * A previous override of {@code executeCode} only delegated to {@code super} and was removed.
 */
@Component
public class JavaNativeCodeSandbox extends JavaCodeSandboxTemplate {
}
如果子类重写了父类的某个方法，
那么执行时，
会优先执行子类中重写过的实现（即覆盖后的实现）；
如果没有重写，则执行模板中的默认实现。
接下来我们改造一下 Docker 代码沙箱的实现：
让它继承模板类，
只自定义其中某些步骤的实现。
我们会发现，JavaDockerCodeSandbox 类其实只需要重写与创建容器（并在容器中运行代码）相关的方法。
package com.yupi.yuojcodesandbox;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import com.yupi.yuojcodesandbox.model.ExecuteCodeRequest;
import com.yupi.yuojcodesandbox.model.ExecuteCodeResponse;
import com.yupi.yuojcodesandbox.model.ExecuteMessage;
import com.yupi.yuojcodesandbox.model.JudgeInfo;
import com.yupi.yuojcodesandbox.utils.ProcessUtils;
import lombok.extern.slf4j.Slf4j;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
/**
 * Template-method base class for Java code sandboxes.
 *
 * <p>{@link #executeCode} fixes the pipeline (save, compile, run, aggregate, clean up); each
 * step is a public method that subclasses may override while inheriting the rest.
 */
@Slf4j
public abstract class JavaCodeSandboxTemplate implements CodeSandbox {

    /** Directory under {@code user.dir} holding one sub-directory per submission. */
    private static final String GLOBAL_CODE_DIR_NAME = "tmpCode";

    /** File name each submission is saved as; it is run as class {@code Main}. */
    private static final String GLOBAL_JAVA_CLASS_NAME = "Main.java";

    /** Wall-clock limit for a single run, in milliseconds. */
    private static final long TIME_OUT = 5000L;

    /**
     * Template method: run the whole sandbox pipeline for one request.
     *
     * <p>The submission's working directory is deleted afterwards even if compiling or running
     * throws.
     *
     * @param executeCodeRequest the code and the list of inputs
     * @return the aggregated execution response
     * @throws RuntimeException if compilation fails or a run cannot be started/completed
     */
    @Override
    public ExecuteCodeResponse executeCode(ExecuteCodeRequest executeCodeRequest) {
        List<String> inputList = executeCodeRequest.getInputList();
        String code = executeCodeRequest.getCode();
        // 1. Save the user's code to a file
        File userCodeFile = saveCodeToFile(code);
        try {
            // 2. Compile the code into a class file
            ExecuteMessage compileFileExecuteMessage = compileFile(userCodeFile);
            System.out.println(compileFileExecuteMessage);
            // 3. Run the code and collect the per-input results
            List<ExecuteMessage> executeMessageList = runFile(userCodeFile, inputList);
            // 4. Aggregate the results into the response
            return getOutputResponse(executeMessageList);
        } finally {
            // 5. Clean up in finally so failed compiles/runs don't leak the submission directory
            boolean b = deleteFile(userCodeFile);
            if (!b) {
                log.error("deleteFile error, userCodeFilePath = {}", userCodeFile.getAbsolutePath());
            }
        }
    }

    /**
     * 1. Save the user's code into its own random sub-directory of the global code directory.
     *
     * @param code the user's source code
     * @return the written source file
     */
    public File saveCodeToFile(String code) {
        String userDir = System.getProperty("user.dir");
        String globalCodePathName = userDir + File.separator + GLOBAL_CODE_DIR_NAME;
        // Create the global code directory if it does not exist yet
        if (!FileUtil.exist(globalCodePathName)) {
            FileUtil.mkdir(globalCodePathName);
        }
        // Isolate each submission in a random directory
        String userCodeParentPath = globalCodePathName + File.separator + UUID.randomUUID();
        String userCodePath = userCodeParentPath + File.separator + GLOBAL_JAVA_CLASS_NAME;
        return FileUtil.writeString(code, userCodePath, StandardCharsets.UTF_8);
    }

    /**
     * 2. Compile the saved source file with {@code javac}.
     *
     * @param userCodeFile the saved source file
     * @return the compiler's execution message
     * @throws RuntimeException if {@code javac} cannot be started, or "编译错误" if it exits non-zero
     */
    public ExecuteMessage compileFile(File userCodeFile) {
        ExecuteMessage executeMessage;
        try {
            // Argument list instead of Runtime.exec(String), which split paths containing spaces
            Process compileProcess = new ProcessBuilder(
                    "javac", "-encoding", "utf-8", userCodeFile.getAbsolutePath()).start();
            executeMessage = ProcessUtils.runProcessAndGetMessage(compileProcess, "编译");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        // Checked outside the try so the "编译错误" exception is not re-wrapped by the catch
        if (executeMessage.getExitValue() != 0) {
            throw new RuntimeException("编译错误");
        }
        return executeMessage;
    }

    /**
     * 3. Run the compiled {@code Main} class once per input and collect each run's result.
     *
     * @param userCodeFile the saved source file; its parent directory is the class path
     * @param inputList    one entry per run, split on whitespace into program arguments
     *                     ({@code null} list is treated as empty)
     * @return one execution message per input, in input order
     * @throws RuntimeException ("执行错误") if a run cannot be started or its result collected
     */
    public List<ExecuteMessage> runFile(File userCodeFile, List<String> inputList) {
        String userCodeParentPath = userCodeFile.getParentFile().getAbsolutePath();
        List<ExecuteMessage> executeMessageList = new ArrayList<>();
        if (inputList == null) {
            return executeMessageList;
        }
        for (String inputArgs : inputList) {
            List<String> runCmd = buildRunCommand(userCodeParentPath, inputArgs);
            try {
                Process runProcess = new ProcessBuilder(runCmd).start();
                startTimeoutWatchdog(runProcess);
                ExecuteMessage executeMessage = ProcessUtils.runProcessAndGetMessage(runProcess, "运行");
                System.out.println(executeMessage);
                executeMessageList.add(executeMessage);
            } catch (Exception e) {
                throw new RuntimeException("执行错误", e);
            }
        }
        return executeMessageList;
    }

    /**
     * Build the {@code java} command for one run as separate arguments.
     *
     * <p>The class path stays a single argument even if it contains spaces. The input is split
     * with StringTokenizer's default delimiters, which is how Runtime.exec(String) split it.
     */
    private static List<String> buildRunCommand(String classPath, String inputArgs) {
        List<String> runCmd = new ArrayList<>();
        runCmd.add("java");
        runCmd.add("-Xmx256m");
        runCmd.add("-Dfile.encoding=UTF-8");
        runCmd.add("-cp");
        runCmd.add(classPath);
        runCmd.add("Main");
        if (inputArgs != null) {
            java.util.StringTokenizer tokenizer = new java.util.StringTokenizer(inputArgs);
            while (tokenizer.hasMoreTokens()) {
                runCmd.add(tokenizer.nextToken());
            }
        }
        return runCmd;
    }

    /**
     * Start a daemon thread that force-kills the process if it is still alive after TIME_OUT ms.
     *
     * <p>destroyForcibly is used because a plain destroy can be delayed by shutdown hooks that
     * the user's code registers.
     */
    private static void startTimeoutWatchdog(Process runProcess) {
        Thread watchdog = new Thread(() -> {
            try {
                Thread.sleep(TIME_OUT);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            // Only report/kill when the program is really still running
            if (runProcess.isAlive()) {
                System.out.println("超时了,中断");
                runProcess.destroyForcibly();
            }
        });
        watchdog.setDaemon(true);
        watchdog.start();
    }

    /**
     * 4. Aggregate the per-run execution messages into a single response.
     *
     * <p>Status 3 and the error text if a run produced non-blank error output (aggregation stops
     * there); status 1 if every run finished normally. Time is the largest per-run time seen.
     *
     * @param executeMessageList results of the individual runs, in input order
     * @return the aggregated response
     */
    public ExecuteCodeResponse getOutputResponse(List<ExecuteMessage> executeMessageList) {
        ExecuteCodeResponse executeCodeResponse = new ExecuteCodeResponse();
        List<String> outputList = new ArrayList<>();
        // Largest run time, used for time-limit checks
        long maxTime = 0;
        for (ExecuteMessage executeMessage : executeMessageList) {
            String errorMessage = executeMessage.getErrorMessage();
            if (StrUtil.isNotBlank(errorMessage)) {
                executeCodeResponse.setMessage(errorMessage);
                // The user's code failed while running
                executeCodeResponse.setStatus(3);
                break;
            }
            outputList.add(executeMessage.getMessage());
            Long time = executeMessage.getTime();
            if (time != null) {
                maxTime = Math.max(maxTime, time);
            }
        }
        // All runs completed normally
        if (outputList.size() == executeMessageList.size()) {
            executeCodeResponse.setStatus(1);
        }
        executeCodeResponse.setOutputList(outputList);
        JudgeInfo judgeInfo = new JudgeInfo();
        judgeInfo.setTime(maxTime);
        // Memory usage is not measured (it would need a third-party library)
        executeCodeResponse.setJudgeInfo(judgeInfo);
        return executeCodeResponse;
    }

    /**
     * 5. Delete the submission's working directory (the parent of the saved source file).
     *
     * @param userCodeFile the saved source file
     * @return {@code true} if deleted or there is no parent directory, {@code false} on failure
     */
    public boolean deleteFile(File userCodeFile) {
        if (userCodeFile.getParentFile() != null) {
            String userCodeParentPath = userCodeFile.getParentFile().getAbsolutePath();
            boolean del = FileUtil.del(userCodeParentPath);
            System.out.println("删除" + (del ? "成功" : "失败"));
            return del;
        }
        return true;
    }

    /**
     * 6. Build the response for a sandbox-side failure (status 2). Not called by executeCode in
     * this class.
     *
     * @param e the failure; its message becomes the response message
     * @return a response with no outputs, status 2 and an empty JudgeInfo
     */
    private ExecuteCodeResponse getErrorResponse(Throwable e) {
        ExecuteCodeResponse executeCodeResponse = new ExecuteCodeResponse();
        executeCodeResponse.setOutputList(new ArrayList<>());
        executeCodeResponse.setMessage(e.getMessage());
        // Status 2: sandbox error
        executeCodeResponse.setStatus(2);
        executeCodeResponse.setJudgeInfo(new JudgeInfo());
        return executeCodeResponse;
    }
}
熟悉模板方法模式的人一看就能看懂这段代码。
package com.yupi.yuojcodesandbox;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import com.yupi.yuojcodesandbox.model.ExecuteCodeRequest;
import com.yupi.yuojcodesandbox.model.ExecuteCodeResponse;
import com.yupi.yuojcodesandbox.model.ExecuteMessage;
import com.yupi.yuojcodesandbox.model.JudgeInfo;
import com.yupi.yuojcodesandbox.utils.ProcessUtils;
import lombok.extern.slf4j.Slf4j;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
/**
 * Template-method base class for Java code sandboxes.
 *
 * <p>{@link #executeCode} fixes the pipeline (save, compile, run, aggregate, clean up); each
 * step is a public method that subclasses may override while inheriting the rest.
 */
@Slf4j
public abstract class JavaCodeSandboxTemplate implements CodeSandbox {

    /** Directory under {@code user.dir} holding one sub-directory per submission. */
    private static final String GLOBAL_CODE_DIR_NAME = "tmpCode";

    /** File name each submission is saved as; it is run as class {@code Main}. */
    private static final String GLOBAL_JAVA_CLASS_NAME = "Main.java";

    /** Wall-clock limit for a single run, in milliseconds. */
    private static final long TIME_OUT = 5000L;

    /**
     * Template method: run the whole sandbox pipeline for one request.
     *
     * <p>The submission's working directory is deleted afterwards even if compiling or running
     * throws.
     *
     * @param executeCodeRequest the code and the list of inputs
     * @return the aggregated execution response
     * @throws RuntimeException if compilation fails or a run cannot be started/completed
     */
    @Override
    public ExecuteCodeResponse executeCode(ExecuteCodeRequest executeCodeRequest) {
        List<String> inputList = executeCodeRequest.getInputList();
        String code = executeCodeRequest.getCode();
        // 1. Save the user's code to a file
        File userCodeFile = saveCodeToFile(code);
        try {
            // 2. Compile the code into a class file
            ExecuteMessage compileFileExecuteMessage = compileFile(userCodeFile);
            System.out.println(compileFileExecuteMessage);
            // 3. Run the code and collect the per-input results
            List<ExecuteMessage> executeMessageList = runFile(userCodeFile, inputList);
            // 4. Aggregate the results into the response
            return getOutputResponse(executeMessageList);
        } finally {
            // 5. Clean up in finally so failed compiles/runs don't leak the submission directory
            boolean b = deleteFile(userCodeFile);
            if (!b) {
                log.error("deleteFile error, userCodeFilePath = {}", userCodeFile.getAbsolutePath());
            }
        }
    }

    /**
     * 1. Save the user's code into its own random sub-directory of the global code directory.
     *
     * @param code the user's source code
     * @return the written source file
     */
    public File saveCodeToFile(String code) {
        String userDir = System.getProperty("user.dir");
        String globalCodePathName = userDir + File.separator + GLOBAL_CODE_DIR_NAME;
        // Create the global code directory if it does not exist yet
        if (!FileUtil.exist(globalCodePathName)) {
            FileUtil.mkdir(globalCodePathName);
        }
        // Isolate each submission in a random directory
        String userCodeParentPath = globalCodePathName + File.separator + UUID.randomUUID();
        String userCodePath = userCodeParentPath + File.separator + GLOBAL_JAVA_CLASS_NAME;
        return FileUtil.writeString(code, userCodePath, StandardCharsets.UTF_8);
    }

    /**
     * 2. Compile the saved source file with {@code javac}.
     *
     * @param userCodeFile the saved source file
     * @return the compiler's execution message
     * @throws RuntimeException if {@code javac} cannot be started, or "编译错误" if it exits non-zero
     */
    public ExecuteMessage compileFile(File userCodeFile) {
        ExecuteMessage executeMessage;
        try {
            // Argument list instead of Runtime.exec(String), which split paths containing spaces
            Process compileProcess = new ProcessBuilder(
                    "javac", "-encoding", "utf-8", userCodeFile.getAbsolutePath()).start();
            executeMessage = ProcessUtils.runProcessAndGetMessage(compileProcess, "编译");
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        // Checked outside the try so the "编译错误" exception is not re-wrapped by the catch
        if (executeMessage.getExitValue() != 0) {
            throw new RuntimeException("编译错误");
        }
        return executeMessage;
    }

    /**
     * 3. Run the compiled {@code Main} class once per input and collect each run's result.
     *
     * @param userCodeFile the saved source file; its parent directory is the class path
     * @param inputList    one entry per run, split on whitespace into program arguments
     *                     ({@code null} list is treated as empty)
     * @return one execution message per input, in input order
     * @throws RuntimeException ("执行错误") if a run cannot be started or its result collected
     */
    public List<ExecuteMessage> runFile(File userCodeFile, List<String> inputList) {
        String userCodeParentPath = userCodeFile.getParentFile().getAbsolutePath();
        List<ExecuteMessage> executeMessageList = new ArrayList<>();
        if (inputList == null) {
            return executeMessageList;
        }
        for (String inputArgs : inputList) {
            List<String> runCmd = buildRunCommand(userCodeParentPath, inputArgs);
            try {
                Process runProcess = new ProcessBuilder(runCmd).start();
                startTimeoutWatchdog(runProcess);
                ExecuteMessage executeMessage = ProcessUtils.runProcessAndGetMessage(runProcess, "运行");
                System.out.println(executeMessage);
                executeMessageList.add(executeMessage);
            } catch (Exception e) {
                throw new RuntimeException("执行错误", e);
            }
        }
        return executeMessageList;
    }

    /**
     * Build the {@code java} command for one run as separate arguments.
     *
     * <p>The class path stays a single argument even if it contains spaces. The input is split
     * with StringTokenizer's default delimiters, which is how Runtime.exec(String) split it.
     */
    private static List<String> buildRunCommand(String classPath, String inputArgs) {
        List<String> runCmd = new ArrayList<>();
        runCmd.add("java");
        runCmd.add("-Xmx256m");
        runCmd.add("-Dfile.encoding=UTF-8");
        runCmd.add("-cp");
        runCmd.add(classPath);
        runCmd.add("Main");
        if (inputArgs != null) {
            java.util.StringTokenizer tokenizer = new java.util.StringTokenizer(inputArgs);
            while (tokenizer.hasMoreTokens()) {
                runCmd.add(tokenizer.nextToken());
            }
        }
        return runCmd;
    }

    /**
     * Start a daemon thread that force-kills the process if it is still alive after TIME_OUT ms.
     *
     * <p>destroyForcibly is used because a plain destroy can be delayed by shutdown hooks that
     * the user's code registers.
     */
    private static void startTimeoutWatchdog(Process runProcess) {
        Thread watchdog = new Thread(() -> {
            try {
                Thread.sleep(TIME_OUT);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
            // Only report/kill when the program is really still running
            if (runProcess.isAlive()) {
                System.out.println("超时了,中断");
                runProcess.destroyForcibly();
            }
        });
        watchdog.setDaemon(true);
        watchdog.start();
    }

    /**
     * 4. Aggregate the per-run execution messages into a single response.
     *
     * <p>Status 3 and the error text if a run produced non-blank error output (aggregation stops
     * there); status 1 if every run finished normally. Time is the largest per-run time seen.
     *
     * @param executeMessageList results of the individual runs, in input order
     * @return the aggregated response
     */
    public ExecuteCodeResponse getOutputResponse(List<ExecuteMessage> executeMessageList) {
        ExecuteCodeResponse executeCodeResponse = new ExecuteCodeResponse();
        List<String> outputList = new ArrayList<>();
        // Largest run time, used for time-limit checks
        long maxTime = 0;
        for (ExecuteMessage executeMessage : executeMessageList) {
            String errorMessage = executeMessage.getErrorMessage();
            if (StrUtil.isNotBlank(errorMessage)) {
                executeCodeResponse.setMessage(errorMessage);
                // The user's code failed while running
                executeCodeResponse.setStatus(3);
                break;
            }
            outputList.add(executeMessage.getMessage());
            Long time = executeMessage.getTime();
            if (time != null) {
                maxTime = Math.max(maxTime, time);
            }
        }
        // All runs completed normally
        if (outputList.size() == executeMessageList.size()) {
            executeCodeResponse.setStatus(1);
        }
        executeCodeResponse.setOutputList(outputList);
        JudgeInfo judgeInfo = new JudgeInfo();
        judgeInfo.setTime(maxTime);
        // Memory usage is not measured (it would need a third-party library)
        executeCodeResponse.setJudgeInfo(judgeInfo);
        return executeCodeResponse;
    }

    /**
     * 5. Delete the submission's working directory (the parent of the saved source file).
     *
     * @param userCodeFile the saved source file
     * @return {@code true} if deleted or there is no parent directory, {@code false} on failure
     */
    public boolean deleteFile(File userCodeFile) {
        if (userCodeFile.getParentFile() != null) {
            String userCodeParentPath = userCodeFile.getParentFile().getAbsolutePath();
            boolean del = FileUtil.del(userCodeParentPath);
            System.out.println("删除" + (del ? "成功" : "失败"));
            return del;
        }
        return true;
    }

    /**
     * 6. Build the response for a sandbox-side failure (status 2). Not called by executeCode in
     * this class.
     *
     * @param e the failure; its message becomes the response message
     * @return a response with no outputs, status 2 and an empty JudgeInfo
     */
    private ExecuteCodeResponse getErrorResponse(Throwable e) {
        ExecuteCodeResponse executeCodeResponse = new ExecuteCodeResponse();
        executeCodeResponse.setOutputList(new ArrayList<>());
        executeCodeResponse.setMessage(e.getMessage());
        // Status 2: sandbox error
        executeCodeResponse.setStatus(2);
        executeCodeResponse.setJudgeInfo(new JudgeInfo());
        return executeCodeResponse;
    }
}