// SpringBoot单体服务无感更新启动 (Spring Boot monolith: seamless, zero-downtime restart on update)
package com. basaltic. warn ;
import cn. hutool. core. io. IoUtil ;
import lombok. SneakyThrows ;
import org. apache. commons. lang3. StringUtils ;
import org. mybatis. spring. annotation. MapperScan ;
import org. springframework. boot. SpringApplication ;
import org. springframework. boot. autoconfigure. SpringBootApplication ;
import org. springframework. boot. web. embedded. tomcat. TomcatServletWebServerFactory ;
import org. springframework. boot. web. servlet. ServletContextInitializer ;
import org. springframework. boot. web. servlet. context. ServletWebServerApplicationContext ;
import org. springframework. boot. web. servlet. server. ServletWebServerFactory ;
import org. springframework. context. ConfigurableApplicationContext ;
import org. springframework. scheduling. annotation. EnableAsync ;
import org. springframework. scheduling. annotation. EnableScheduling ;
import org. springframework. transaction. annotation. EnableTransactionManagement ;
import springfox. documentation. oas. annotations. EnableOpenApi ;
import java. io. InputStream ;
import java. lang. reflect. Method ;
import java. net. ServerSocket ;
import java. util. ArrayList ;
import java. util. List ;
import java. util. concurrent. TimeUnit ;
import java. util. concurrent. atomic. AtomicInteger ;
@MapperScan("com.basaltic.warn.sys.mapper")
@SpringBootApplication
@EnableTransactionManagement
@EnableOpenApi
@EnableScheduling
@EnableAsync
public class BasalticOneNewApplication {

    /** Port assumed when the environment does not provide a numeric {@code server.port}. */
    private static final int DEFAULT_PORT = 7089;

    /** Preferred temporary port for the new instance while the old one still holds the main port. */
    private static final int DEFAULT_FALLBACK_PORT = 7069;

    /** Spring Boot command-line option prefix that sets {@code server.port}. */
    private static final String SERVER_PORT_ARG = "--server.port=";

    /**
     * Starts the application.
     *
     * <p>If startup fails while the configured {@code server.port} is already bound (presumably by
     * the previously running instance of this service), the new instance is started on a temporary
     * port. Whatever listens on the original port is then killed, and the embedded web server is
     * re-created on the original port. Any other startup failure is rethrown unchanged.
     *
     * @param args command-line arguments passed through to Spring Boot
     */
    @SneakyThrows
    public static void main(String[] args) {
        AtomicInteger port = new AtomicInteger(DEFAULT_PORT);
        SpringApplication app = new SpringApplication(BasalticOneNewApplication.class);
        // Record the effective server.port before the web server starts, so the recovery path
        // below knows which port the old instance is holding.
        app.addInitializers(context -> {
            String portStr = context.getEnvironment().getProperty("server.port");
            System.out.println("The port is: " + portStr);
            if (StringUtils.isNotBlank(portStr) && StringUtils.isNumeric(portStr)) {
                port.set(Integer.parseInt(portStr));
            }
        });
        try {
            app.run(args);
        } catch (Exception e) {
            int mainPort = port.get();
            if (!isPortInUse(mainPort)) {
                // Not a port conflict: retrying would just fail again and hide the real cause.
                throw e;
            }
            takeOverPort(args, mainPort);
        }
    }

    /**
     * Starts a second application context on a temporary port, and kills the process(es) holding
     * {@code mainPort} until the port is free. Then it serves the context from {@code mainPort} and
     * stops the temporary web server.
     *
     * <p>NOTE(review): the server started on {@code mainPort} is created directly from the factory.
     * The context still references the (stopped) temporary server, so the new server is presumably
     * not stopped by the context's own shutdown. Confirm shutdown behavior. This also relies on the
     * Tomcat factory and on Spring Boot's private {@code getSelfInitializer} method.
     *
     * @param args     original command-line arguments
     * @param mainPort the port the previous instance is holding and this instance should end up on
     */
    @SneakyThrows
    private static void takeOverPort(String[] args, int mainPort) {
        int fallbackPort = chooseFallbackPort(mainPort);
        ConfigurableApplicationContext run =
                SpringApplication.run(BasalticOneNewApplication.class, replaceServerPortArg(args, fallbackPort));

        // Keep killing whatever listens on mainPort until it is free. There is no attempt limit.
        boolean windows = System.getProperty("os.name").startsWith("Windows");
        while (isPortInUse(mainPort)) {
            if (windows) {
                killWinTidByPort(mainPort);
            } else {
                killLinuxTidByPort(mainPort);
            }
            System.out.println("已经占用");
            TimeUnit.SECONDS.sleep(2);
        }

        String[] beanNames = run.getBeanFactory().getBeanNamesForType(ServletWebServerFactory.class);
        ServletWebServerFactory webServerFactory =
                run.getBeanFactory().getBean(beanNames[0], ServletWebServerFactory.class);
        ((TomcatServletWebServerFactory) webServerFactory).setPort(mainPort);
        // getSelfInitializer is private; it yields the initializer that wires this context's
        // servlets/filters into a freshly created servlet container.
        Method selfInitializer = ServletWebServerApplicationContext.class.getDeclaredMethod("getSelfInitializer");
        selfInitializer.setAccessible(true);
        ServletContextInitializer initializer = (ServletContextInitializer) selfInitializer.invoke(run);
        webServerFactory.getWebServer(initializer).start();
        ((ServletWebServerApplicationContext) run).getWebServer().stop();
    }

    /**
     * Picks the temporary port for the new instance. The preferred port is used when it differs
     * from {@code mainPort} and is free; otherwise the OS is asked for any free port.
     */
    @SneakyThrows
    private static int chooseFallbackPort(int mainPort) {
        if (DEFAULT_FALLBACK_PORT != mainPort && !isPortInUse(DEFAULT_FALLBACK_PORT)) {
            return DEFAULT_FALLBACK_PORT;
        }
        try (ServerSocket probe = new ServerSocket(0)) {
            return probe.getLocalPort();
        }
    }

    /**
     * Returns a copy of {@code args} in which every existing {@code --server.port=} option is
     * replaced by a single one set to {@code serverPort}. Spring joins repeated option values with
     * commas (e.g. {@code "7089,7069"}), which would leave {@code server.port} unparseable.
     */
    private static String[] replaceServerPortArg(String[] args, int serverPort) {
        List<String> result = new ArrayList<>(args.length + 1);
        for (String arg : args) {
            if (!arg.startsWith(SERVER_PORT_ARG)) {
                result.add(arg);
            }
        }
        result.add(SERVER_PORT_ARG + serverPort);
        return result.toArray(new String[0]);
    }

    /**
     * Reports whether {@code port} cannot currently be bound by this process.
     *
     * @return {@code true} if opening a server socket on the port fails for any reason
     */
    private static boolean isPortInUse(int port) {
        try (ServerSocket ignored = new ServerSocket(port)) {
            return false;
        } catch (Exception e) {
            return true;
        }
    }

    /**
     * Force-kills ({@code kill -9}) every process that {@code lsof} reports as LISTENing on
     * {@code port}. Runs through {@code sh -c}; {@code port} is an int, so it cannot inject shell
     * syntax.
     */
    @SneakyThrows
    public static void killLinuxTidByPort(int port) {
        String command = String.format("lsof -i :%d | grep LISTEN | awk '{print $2}' | xargs kill -9", port);
        Runtime.getRuntime().exec(new String[]{"sh", "-c", command}).waitFor();
    }

    /**
     * Force-kills ({@code taskkill /F}) the Windows process bound to {@code port}, as found by
     * {@link #getWinTidByPort(int)}. Does nothing when no such process is found.
     */
    @SneakyThrows
    public static void killWinTidByPort(int port) {
        int pid = getWinTidByPort(port);
        if (pid == 0) {
            // Nothing found; "taskkill /PID 0" would only fail.
            return;
        }
        Process process = new ProcessBuilder("taskkill", "/F", "/PID", String.valueOf(pid)).start();
        process.waitFor(3, TimeUnit.SECONDS);
    }

    /**
     * Looks up, via {@code netstat -ano}, the PID of a Windows process whose local address uses
     * {@code port}.
     *
     * @return the first non-zero PID found, or 0 if none
     */
    @SneakyThrows
    public static int getWinTidByPort(int port) {
        Process process = new ProcessBuilder("cmd.exe", "/c", "netstat -ano | findstr :" + port).start();
        List<String> lines;
        // Drain stdout before waiting so a full pipe buffer can never stall the child process.
        try (InputStream stdout = process.getInputStream()) {
            lines = IoUtil.readUtf8Lines(stdout, new ArrayList<String>());
        }
        process.waitFor(3, TimeUnit.SECONDS);
        return findLocalPortPid(lines, port);
    }

    /**
     * Extracts the PID from {@code netstat -ano} rows whose Local Address column ends with
     * {@code ":" + port}. Rows where the port only appears as the foreign address (i.e. clients of
     * that port) are ignored. PID 0 (e.g. TIME_WAIT rows) is skipped.
     *
     * @return the first matching non-zero PID, or 0 if none
     */
    private static int findLocalPortPid(List<String> netstatLines, int port) {
        String portSuffix = ":" + port;
        for (String line : netstatLines) {
            if (StringUtils.isBlank(line)) {
                continue;
            }
            // Columns: Proto, Local Address, Foreign Address, [State,] PID
            String[] cols = line.trim().split("\\s+");
            if (cols.length < 4 || !cols[1].endsWith(portSuffix)) {
                continue;
            }
            String pidStr = cols[cols.length - 1];
            if (StringUtils.isNumeric(pidStr)) {
                int pid = Integer.parseInt(pidStr);
                if (pid != 0) {
                    return pid;
                }
            }
        }
        return 0;
    }
}