Implementing Text‑Based Image Search Using OCR, Transformers, and Vector Databases
This article explains how to build a text‑to‑image search system by first extracting text with OCR, then storing image paths and textual embeddings in a SQLite or Milvus vector database, and finally improving retrieval with Transformer‑based sentence embeddings and image‑captioning models.
1. Introduction
The previous article described an image‑by‑image search using convolutional neural networks and a vector database; this one focuses on searching images by textual queries.
2. OCR + Text Search
OCR (Optical Character Recognition) extracts visible text from images using tools such as Tesseract. The extracted strings are stored together with image file paths in a database, enabling simple fuzzy text queries.
2.1 Text Recognition
import os, cv2
import numpy as np
import pytesseract
from tqdm import tqdm
from PIL import Image
from sqlalchemy import create_engine, String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session
base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
for file in files:
try:
image = Image.open(file)
string = pytesseract.image_to_string(image, lang='chi_sim')
print(file, ":", string.strip())
except Exception as e:
pass2.2 Storing in SQLite
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String
class Base(DeclarativeBase):
    """Shared declarative base for the ORM models in this article."""
class ImageInformation(Base):
    """One row per image: the file path plus its OCR-extracted text."""

    __tablename__ = "image_information"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    filepath: Mapped[str] = mapped_column(String(255))  # absolute image path
    content: Mapped[str] = mapped_column(String(255))   # OCR text (truncated)

    def __repr__(self) -> str:
        # Fix: the original repr was copy-pasted and printed "User(...)".
        return f"ImageInformation(id={self.id!r}, filepath={self.filepath!r}, content={self.content!r})"
engine = create_engine("sqlite:///image_search.db", echo=False)
Base.metadata.create_all(engine)2.3 Inserting OCR Results
base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
bar = tqdm(total=len(files))
for file in files:
try:
image = Image.open(file)
string = pytesseract.image_to_string(image, lang='chi_sim').strip()
file = file[:255] if len(file) > 255 else file
string = string[:255] if len(string) > 255 else string
with Session(engine) as session:
info = ImageInformation(filepath=file, content=string)
session.add_all([info])
session.commit()
except Exception as e:
pass
bar.update(1)2.4 Simple Text Query
keyword = '你好'
w, h = 224, 224
with Session(engine) as session:
stmt = select(ImageInformation).where(ImageInformation.content.contains(keyword)).limit(8)
images = [cv2.resize(cv2.imread(ii.filepath), (w, h)) for ii in session.scalars(stmt)]
if len(images) > 0:
result = np.hstack(images)
cv2.imwrite("result.jpg", result)
else:
print("没有找到结果")This approach works for short queries but fails on longer, semantically rich texts.
3. Transformer‑Based Improvement
To overcome pure string matching limitations, the pipeline now encodes OCR‑extracted text with a Transformer model, stores the resulting embeddings in a vector database (Milvus), and performs similarity search on embeddings.
3.1 Creating a Milvus Collection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
connections.connect(host='127.0.0.1', port='19530')
def create_milvus_collection(collection_name, dim):
if utility.has_collection(collection_name):
utility.drop_collection(collection_name)
fields = [
FieldSchema(name='id', dtype=DataType.INT64, descrition='ids', max_length=500, is_primary=True, auto_id=True),
FieldSchema(name='filepath', dtype=DataType.VARCHAR, description='filepath', max_length=512),
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, descrition='embedding vectors', dim=dim),
]
schema = CollectionSchema(fields=fields, description='reverse image search')
collection = Collection(name=collection_name, schema=schema)
index_params = {'metric_type': 'L2', 'index_type': "IVF_FLAT", 'params': {"nlist": 2048}}
collection.create_index(field_name="embedding", index_params=index_params)
return collection
collection = create_milvus_collection('image_information', 768)3.2 Text‑to‑Vector (text2vec)
from text2vec import SentenceModel
model = SentenceModel('shibing624/text2vec-base-chinese')
embeddings = model.encode(['不要温顺地走进那个良夜'])
print(embeddings.shape) # (1, 768)3.3 Storing Embeddings
from text2vec import SentenceModel
model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
base_path = "G:\datasets\emoji"
files = [os.path.join(base_path, file) for file in os.listdir(base_path) if file.endswith(".jpg")]
bar = tqdm(total=len(files))
for idx, file in enumerate(files):
try:
image = Image.open(file)
string = pytesseract.image_to_string(image, lang='chi_sim').strip()
embedding = model.encode([string])[0]
collection.insert([[file], [embedding]])
except Exception as e:
pass
bar.update(1)3.4 Embedding‑Based Search
import cv2
import numpy as np
from text2vec import SentenceModel
from pymilvus import connections, Collection
model = SentenceModel('shibing624/text2vec-base-chinese', device="cuda")
connections.connect(host='127.0.0.1', port='19530')
collection = Collection(name='image_information')
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
collection.load()
keyword = "今天不开心"
embedding = model.encode([keyword])
results = collection.search(data=[embedding[0]], anns_field='embedding', param=search_params, output_fields=['filepath'], limit=10, consistency_level="Strong")
collection.release()
w, h = 224, 224
images = []
for result in results[0]:
filepath = result.entity.get('filepath')
img = cv2.resize(cv2.imread(filepath), (w, h))
images.append(np.array(img))
result_img = np.hstack(images)
cv2.imwrite("result.jpg", result_img)Embedding search retrieves semantically related images even when the exact characters are absent.
4. Image‑Captioning Based Search
When images contain no readable text, an Image‑to‑Text (captioning) model generates descriptive sentences, which are then encoded and stored similarly.
4.1 Captioning Model Loading
import os
import torch
from tqdm import tqdm
from PIL import Image
from text2vec import SentenceModel
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from pymilvus import Collection
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    """Load the sentence encoder and the image-captioning model.

    Returns:
        (sentence_model, caption_model, image_processor, tokenizer)
    """
    # Fix: use the detected device instead of hard-coding "cuda", so this
    # also runs on CPU-only machines (the device fallback above was unused).
    sentence_model = SentenceModel('shibing624/text2vec-base-chinese', device=str(device))
    model = VisionEncoderDecoderModel.from_pretrained("bipin/image-caption-generator")
    image_processor = ViTImageProcessor.from_pretrained("bipin/image-caption-generator")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model.to(device)
    return sentence_model, model, image_processor, tokenizer
def get_embedding(filepath):
"""Convert an image to a caption vector."""
pixel_values = image_processor(images=[Image.open(filepath)], return_tensors="pt").pixel_values.to(device)
output_ids = model.generate(pixel_values, num_beams=4, max_length=128)
pred = tokenizer.decode(output_ids[0], skip_special_tokens=True)
return sentence_model.encode(pred)4.2 Inserting Caption Embeddings
connections.connect(host='127.0.0.1', port='19530')
collection = Collection("image_information")
collection.load()
sentence_model, model, image_processor, tokenizer = load_model()
base_path = "G:\datasets\people"
files = [os.path.join(base_path, file) for file in os.listdir(base_path)]
bar = tqdm(total=len(files))
for idx, file in enumerate(files):
try:
embedding = get_embedding(file)
collection.insert([[file], [embedding]])
except Exception as e:
pass
bar.update(1)4.3 Caption‑Based Search
search_params = {"metric_type": "L2", "params": {"nprobe": 10}, "offset": 5}
keyword = "girl"
embedding = sentence_model.encode([keyword])
results = collection.search(data=[embedding[0]], anns_field='embedding', param=search_params, output_fields=['filepath'], limit=10, consistency_level="Strong")
collection.release()
w, h = 224, 224
images = []
for result in results[0]:
filepath = result.entity.get('filepath')
img = cv2.resize(cv2.imread(filepath), (w, h))
images.append(np.array(img))
result_img = np.hstack(images)
cv2.imwrite("result.jpg", result_img)Using multiple generated captions (by adjusting the temperature parameter) can improve recall for ambiguous images.
4.4 Temperature‑Controlled Generation
output_ids = model.generate(pixel_values, num_beams=4, max_length=128, temperature=0.8)Lower temperature yields more deterministic captions, while higher values introduce diversity.
Rare Earth Juejin Tech Community
Juejin, a tech community that helps developers grow.
How this landed with the community
Was this worth your time?
0 Comments
Thoughtful readers leave field notes, pushback, and hard-won operational detail here.