from nnunetv2.imageio.natural_image_reader_writer import NaturalImage2DIO
from torch import device
import os
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import numpy as np
from time import time
from cv2 import imwrite
# Trained nnU-Net model folder to load (placeholder path: replace with the real
# nnUNet_results/<dataset>/<trainer>__<plans>__2d directory). An absolute path is
# used directly instead of assembling it with os.path.join.
MODEL_TRAINING_OUTPUT_DIR = "/xxx/nnUNet/nnUNet_results/xxx/nnUNetTrainer__nnUNetPlans__2d"

# Input images, in nnU-Net channel naming (<case>_<channel>.png).
IMAGE_PATHS = ("./img_1_0000.png", "./img_2_0000.png")

# Where each prediction is written; PRED_PATHS[i] receives the result for IMAGE_PATHS[i].
OUTPUT_DIR = "./multiple_npy_pred_recover"
PRED_PATHS = (
    "./multiple_npy_pred_recover/img_1_0000_npy_pred.png",
    "./multiple_npy_pred_recover/img_2_0000_npy_pred.png",
)

# Restore the custom grey values used in the original label images:
# predicted class index -> pixel value written to the output PNG.
LABEL_RECOVERY_MAP = {1: 128, 2: 196, 3: 255}


def seg_to_label_image(seg: np.ndarray, recover_map: dict) -> np.ndarray:
    """Convert a channel-first segmentation into a channel-last uint8 label image.

    Args:
        seg: Integer segmentation with a leading channel axis. For nnU-Net 2D
            natural images this is presumably (1, H, W) -- confirm against the
            predictor output.
        recover_map: Maps predicted label value -> value to write. Labels not
            present in the map are written unchanged.

    Returns:
        A C-contiguous uint8 array of shape (H, W, C), suitable for ``cv2.imwrite``.

    Raises:
        ValueError: if ``seg`` is not 3-D, or if ``seg`` or ``recover_map``
            contains values outside 0..255.
    """
    seg = np.asarray(seg)
    if seg.ndim != 3:
        raise ValueError(f"expected a 3-D (C, H, W) segmentation, got shape {seg.shape}")
    # A bare astype(np.uint8) would silently wrap labels >= 256; fail loudly instead.
    if seg.size and (seg.min() < 0 or seg.max() > 255):
        raise ValueError("segmentation labels must lie in 0..255 to be stored as uint8")

    # Lookup table: every remapping is applied in a single pass, so a target value
    # that is also a source label (e.g. {1: 2, 2: 3}) is not remapped twice.
    lut = np.arange(256, dtype=np.uint8)
    for predicted, recovered in recover_map.items():
        if not (0 <= predicted <= 255 and 0 <= recovered <= 255):
            raise ValueError(f"label mapping {predicted} -> {recovered} is outside 0..255")
        lut[predicted] = recovered

    label_image = lut[seg.astype(np.intp)]
    # Channel-first -> channel-last for OpenCV; make it contiguous before handing to cv2.
    return np.ascontiguousarray(label_image.transpose(1, 2, 0))


def _build_predictor():
    """Create an nnUNetPredictor and load fold 0's best checkpoint from MODEL_TRAINING_OUTPUT_DIR."""
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=False,
        use_mirroring=False,
        perform_everything_on_device=False,
        # To run on GPU instead, use device=device("cuda", 0).
        device=device("cpu"),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=False,
    )
    predictor.initialize_from_trained_model_folder(
        model_training_output_dir=MODEL_TRAINING_OUTPUT_DIR,
        use_folds=(0,),
        checkpoint_name="checkpoint_best.pth",
    )
    return predictor


def _read_image(image_path: str):
    """Read one image with nnU-Net's natural-image reader.

    Returns:
        (image as a float32 numpy array, properties dict from the reader).
    """
    image, properties = NaturalImage2DIO().read_images([image_path])
    # Cast every input the same way so all images share one preprocessing path.
    return np.asarray(image, dtype=np.float32), properties


def main() -> None:
    """Segment IMAGE_PATHS in memory and write label images to PRED_PATHS."""
    tic = time()
    predictor = _build_predictor()

    # The predictor expects real lists (it wraps non-list inputs as a single case).
    loaded = [_read_image(path) for path in IMAGE_PATHS]
    images = [image for image, _ in loaded]
    properties = [props for _, props in loaded]

    # No output files are given, so the segmentations are returned (a list, one per
    # input). Passing output file names instead writes them to disk and returns nothing.
    rets = predictor.predict_from_list_of_npy_arrays(
        image_or_list_of_images=images,
        segs_from_prev_stage_or_list_of_segs_from_prev_stage=None,
        properties_or_list_of_properties=properties,
        truncated_ofname=None,
        num_processes=1,
        save_probabilities=False,
        num_processes_segmentation_export=1,
    )
    if len(rets) != len(PRED_PATHS):
        raise RuntimeError(
            f"got {len(rets)} predictions for {len(PRED_PATHS)} output paths"
        )

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    for seg, pred_path in zip(rets, PRED_PATHS):
        label_image = seg_to_label_image(seg, LABEL_RECOVERY_MAP)
        # cv2.imwrite reports failure by returning False rather than raising.
        if not imwrite(pred_path, label_image):
            raise OSError(f"cv2.imwrite failed to write {pred_path}")

    print(f"==>> time cost: {time() - tic}")


if __name__ == "__main__":
    main()