一、安装包
pip install transformers
pip install torch
pip install SentencePiece
pip install timm
pip install accelerate
pip install pytesseract pillow pandas
pip install tesseract
下载模型:
https://huggingface.co/microsoft/table-transformer-structure-recognition/tree/main
https://huggingface.co/microsoft/table-transformer-detection/tree/main
二、安装tesseract-ocr
我这里使用的是 Windows 系统。
下载:tesseract-ocr-w64-setup-5.4.0.20240606.exe 安装
https://tesseract-ocr.github.io/tessdoc/Downloads.html
https://digi.bib.uni-mannheim.de/tesseract/ 【tesseract-ocr-w64-setup-5.4.0.20240606.exe】
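添加环境变量:将 Tesseract-OCR 的安装目录(默认为 C:\Program Files\Tesseract-OCR)添加到系统环境变量 Path 中,以便 pytesseract 能够找到并调用 tesseract.exe。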
三、准备图片
下载:https://download.csdn.net/download/xiaoxionglove/90063200
四、编写代码
from PIL import Image
from transformers import DetrImageProcessor
from transformers import TableTransformerForObjectDetection
import torch
import matplotlib.pyplot as plt
import os
import psutil
import time
from transformers import DetrFeatureExtractor
# Shared image processor used by the structure-recognition helpers below.
# DetrFeatureExtractor is only a deprecated alias of DetrImageProcessor
# (imported at the top of this script), so the supported class is used
# directly to avoid the deprecation warning.
feature_extractor = DetrImageProcessor()
import pandas as pd
import pytesseract
# Load the pretrained table-detection checkpoint and bind it to the global
# `model`, which the functions below read at call time.
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")
# Palette of RGB triples (components in the 0-1 range) used as the outline
# colours when bounding boxes are drawn.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def plot_results(pil_img, scores, labels, boxes):
    """Show an image with every detection drawn as a labelled rectangle.

    Args:
        pil_img: Image to display; handed straight to ``plt.imshow``.
        scores: Per-detection confidence scores; must provide ``.tolist()``.
        labels: Per-detection class ids; must provide ``.tolist()``. Each id
            is translated through the global ``model.config.id2label``.
        boxes: Per-detection ``(xmin, ymin, xmax, ymax)`` boxes in the pixel
            coordinates of ``pil_img``; must provide ``.tolist()``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    from itertools import cycle

    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()

    # Cycle through the palette so every detection gets a colour; a
    # fixed-length repeated list would make zip() silently drop any
    # detections beyond its length.
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist(), cycle(COLORS))
    for score, label, (xmin, ymin, xmax, ymax), color in detections:
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=3,
            )
        )
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))

    plt.axis('off')
    plt.show()
def table_detection(file_path):
    """Detect tables in an image file, plot them, and return their boxes.

    The image is loaded as RGB, downscaled to half its width and height,
    run through the global ``model``, and the detections scoring at least
    0.7 are plotted with ``plot_results``.

    Args:
        file_path: Path of the image file to analyse.

    Returns:
        The ``boxes`` entry of the post-processed detections, expressed in
        the pixel coordinates of the half-size image that was analysed and
        plotted.
    """
    image = Image.open(file_path).convert("RGB")
    width, height = image.size
    # Image.resize returns a new image instead of resizing in place, so the
    # result has to be rebound for the downscale to take effect. max(1, ...)
    # keeps a 1-pixel dimension from collapsing to 0.
    image = image.resize((max(1, int(width * 0.5)), max(1, int(height * 0.5))))

    # Named `processor` so it does not shadow the module-level
    # `feature_extractor`.
    processor = DetrImageProcessor()
    encoding = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)

    # Post-processing expects (height, width) of the image the boxes should
    # be mapped onto, i.e. the resized one.
    width, height = image.size
    results = processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=[(height, width)]
    )[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])
    return results['boxes']
# Report the process's resident memory (in MB) before any image is processed.
ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
print(f"ram usage : {ram_usage}")

# Run table detection on the first three files of the test folder, printing
# the wall-clock time and resident memory after each one.
count = 0
root = "Detection_Images_Test/"
for file in os.listdir(root):
    file_path = os.path.join(root, file)
    start_time = time.time()
    pred_bbox = table_detection(file_path)
    count += 1
    end_time = time.time()
    time_usage = end_time - start_time
    ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    # `count` has already been incremented, so it is the 1-based iteration
    # number; adding 1 again would label the first file as iteration 2.
    print(f"Iteration {count} - RAM Usage: {ram_usage:.2f} MB, Time Usage: {time_usage:.2f} seconds")
    if count > 2:
        break
# Open one sample table image and convert it to RGB.
file = 'img_test/PMC1064078_table_0.jpg.png'
image = Image.open(file).convert("RGB")
# Bare expression: presumably left over from a notebook cell that displayed
# the image; it has no effect when this file runs as a plain script.
image
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import TableTransformerForObjectDetection
# Rebind the global `model` to the structure-recognition checkpoint. Every
# function that reads `model` at call time (inference and the id2label lookup
# in the plotting helpers) uses this checkpoint from here on.
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")
def cell_detection(file_path):
    """Run the global ``model`` on a table image and plot what it detects.

    The image is loaded as RGB, downscaled to half its width and height,
    encoded with the global ``feature_extractor``, and every detection
    scoring at least 0.6 is drawn with ``plot_results``.

    Args:
        file_path: Path of the table image to analyse.

    Returns:
        None. The detections are only plotted.
    """
    image = Image.open(file_path).convert("RGB")
    width, height = image.size
    # Image.resize returns a new image instead of resizing in place, so the
    # result has to be rebound for the downscale to take effect. max(1, ...)
    # keeps a 1-pixel dimension from collapsing to 0.
    image = image.resize((max(1, int(width * 0.5)), max(1, int(height * 0.5))))

    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)

    # image.size is (width, height); post-processing wants (height, width).
    target_sizes = [image.size[::-1]]
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=0.6, target_sizes=target_sizes
    )[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])
# Bare expression: presumably left over from a notebook cell that displayed
# the class-id -> name mapping; it has no effect when run as a plain script.
model.config.id2label

# Report the process's resident memory (in MB) before any image is processed.
ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
print(f"ram usage : {ram_usage}")

# Run cell_detection on the first three files of the test folder, printing
# the wall-clock time and resident memory after each one.
count = 0
root = "img_test/"
for file in os.listdir(root):
    file_path = os.path.join(root, file)
    start_time = time.time()
    cell_detection(file_path)
    count += 1
    end_time = time.time()
    time_usage = end_time - start_time
    ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    # `count` has already been incremented, so it is the 1-based iteration
    # number; adding 1 again would label the first file as iteration 2.
    print(f"Iteration {count} - RAM Usage: {ram_usage:.2f} MB, Time Usage: {time_usage:.2f} seconds")
    if count > 2:
        break
def plot_results_specific(pil_img, scores, labels, boxes, lab):
    """Show an image with only the detections of one class drawn on it.

    Args:
        pil_img: Image to display; handed straight to ``plt.imshow``.
        scores: Per-detection confidence scores; must provide ``.tolist()``.
        labels: Per-detection class ids; must provide ``.tolist()``. Ids are
            translated through the global ``model.config.id2label``.
        boxes: Per-detection ``(xmin, ymin, xmax, ymax)`` boxes in the pixel
            coordinates of ``pil_img``; must provide ``.tolist()``.
        lab: Class id to draw; detections with any other id are skipped.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    from itertools import cycle

    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()

    # The palette is cycled over *all* detections (not just the drawn ones),
    # so a box keeps the same colour it would get when every class is drawn.
    # cycle() also avoids zip() silently dropping detections once a
    # fixed-length colour list runs out.
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist(), cycle(COLORS))
    for score, label, (xmin, ymin, xmax, ymax), color in detections:
        if label != lab:
            continue
        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color=color,
                linewidth=3,
            )
        )
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))

    plt.axis('off')
    plt.show()
def draw_box_specific(image_path, labelnum):
    """Run the global ``model`` on an image and plot one class of detections.

    Args:
        image_path: Path of the image file to analyse.
        labelnum: Class id whose boxes should be drawn; forwarded to
            ``plot_results_specific``.

    Returns:
        None. The matching detections are only plotted.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    img_w, img_h = rgb_image.size

    model_inputs = feature_extractor(rgb_image, return_tensors="pt")
    with torch.no_grad():
        raw_outputs = model(**model_inputs)

    # Post-processing takes the target size as (height, width).
    per_image = feature_extractor.post_process_object_detection(
        raw_outputs,
        threshold=0.7,
        target_sizes=[(img_h, img_w)],
    )
    detections = per_image[0]
    plot_results_specific(
        rgb_image,
        detections['scores'],
        detections['labels'],
        detections['boxes'],
        labelnum,
    )
def compute_boxes(image_path):
    """Run the global ``model`` on an image and return its detections.

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        A ``(boxes, labels)`` pair of plain Python lists holding the
        detections that scored at least 0.7: ``boxes`` contains
        ``[xmin, ymin, xmax, ymax]`` entries in the image's pixel
        coordinates and ``labels`` the matching class ids.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    img_w, img_h = rgb_image.size

    model_inputs = feature_extractor(rgb_image, return_tensors="pt")
    with torch.no_grad():
        raw_outputs = model(**model_inputs)

    # Post-processing takes the target size as (height, width).
    per_image = feature_extractor.post_process_object_detection(
        raw_outputs,
        threshold=0.7,
        target_sizes=[(img_h, img_w)],
    )
    detections = per_image[0]
    return detections['boxes'].tolist(), detections['labels'].tolist()
# Class ids used to build the cell grid: boxes labelled 2 supply the vertical
# extent of each cell (rows) and boxes labelled 1 the horizontal extent
# (columns). NOTE(review): this matches the id2label mapping published for
# microsoft/table-transformer-structure-recognition - confirm against
# model.config.id2label if a different checkpoint is loaded.
_COLUMN_LABEL = 1
_ROW_LABEL = 2


def _ocr_cell(image, box, scale=4):
    """Crop ``box`` out of ``image``, enlarge it, and return the OCR'd text."""
    cell_image = image.crop(box)
    if cell_image.width == 0 or cell_image.height == 0:
        # A degenerate box cannot be resized or read; treat it as empty.
        return ""
    # Small crops are enlarged before OCR, as the script has always done.
    enlarged = cell_image.resize(
        (cell_image.width * scale, cell_image.height * scale),
        resample=Image.LANCZOS,
    )
    return pytesseract.image_to_string(enlarged).rstrip()


def extract_table(image_path):
    """Read the table in an image into a DataFrame via detection plus OCR.

    Row and column boxes from ``compute_boxes`` are intersected into a grid
    of cells ordered top-to-bottom, left-to-right. The first row of cells is
    OCR'd into the column headers and every following row into a data row.

    Args:
        image_path: Path of the table image.

    Returns:
        A ``pandas.DataFrame`` with one column per detected table column and
        one row per detected table row after the header row. If no row or no
        column is detected, an empty DataFrame is returned.
    """
    image = Image.open(image_path).convert("RGB")
    boxes, labels = compute_boxes(image_path)

    detections = list(zip(boxes, labels))
    row_boxes = sorted(
        (box for box, label in detections if label == _ROW_LABEL),
        key=lambda box: box[1],
    )
    col_boxes = sorted(
        (box for box, label in detections if label == _COLUMN_LABEL),
        key=lambda box: box[0],
    )
    if not row_boxes or not col_boxes:
        # Nothing to intersect; indexing into an empty cell list would raise.
        return pd.DataFrame()

    # Each cell takes its x-range from a column box and its y-range from a
    # row box. Keeping the grid as explicit rows makes the column count
    # simply len(col_boxes), which is also right for a single-row table.
    grid = [
        [(col[0], row[1], col[2], row[3]) for col in col_boxes]
        for row in row_boxes
    ]

    header_cells, *body_rows = grid
    headers = [_ocr_cell(image, cell) for cell in header_cells]
    # The full OCR text of each cell is kept; it must not be cut down to
    # `number of columns` characters, which are unrelated quantities.
    records = [[_ocr_cell(image, cell) for cell in cells] for cells in body_rows]

    # Building the frame in one go also tolerates duplicate or empty header
    # strings coming out of OCR.
    return pd.DataFrame(records, columns=headers)
# End-to-end run on one sample image.
image_path = 'img_test/PMC1112589_table_0.jpg'
# Plot only the detections whose class id is 1.
draw_box_specific(image_path,1)
# OCR the detected cells into a DataFrame and save it without the index column.
df = extract_table(image_path)
df.to_csv('data.csv', index=False)
至此,我们已将图片中的表格识别出来,并保存到 data.csv 文件中。