LSTM‑Based Advertising Inventory Forecasting with Embedding and Incremental Training at Ctrip
This article presents Ctrip's end‑to‑end solution for precise ad‑inventory forecasting using an LSTM model combined with entity embedding, covering data preprocessing, K‑means clustering, model architecture, offline‑online incremental training, early‑stop mechanisms, evaluation metrics, and Python service deployment.
Background – Accurate ad‑inventory estimation is crucial for commercial ad‑traffic monetization and fine‑grained operational control. Ctrip adopted an LSTM model with embedding to capture long‑term dependencies and holiday effects while supporting multi‑dimensional targeting.
Challenges – Numerous influencing factors (holidays, weekends, disasters), sparse daily samples, high‑dimensional cross‑targeting, and rapidly evolving inventory data requiring frequent model updates.
Algorithm Overview – LSTM, a gated RNN, is used as the backbone; entity embedding converts categorical features into dense vectors, enhancing model generalization.
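The entity-embedding idea can be sketched with PyTorch's nn.Embedding; the feature names and cardinalities below are made up for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical categorical features: ad position (20 values) and city tier (4 values).
# Each gets its own embedding table mapping an integer id to a dense vector.
emb_dims = [(20, 8), (4, 2)]   # (cardinality, embedding_dim) pairs
emb_layers = nn.ModuleList([nn.Embedding(c, d) for c, d in emb_dims])

batch = torch.tensor([[3, 1], [17, 0]])   # two samples, two categorical ids each
dense = torch.cat([emb(batch[:, i]) for i, emb in enumerate(emb_layers)], dim=1)
print(dense.shape)   # torch.Size([2, 10]) — 8 + 2 dense dims replace the sparse ids
```

Because the embedding vectors are learned jointly with the forecaster, categories with similar inventory behavior end up close together in the embedded space, which is what gives the generalization benefit.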
Data Processing – Features are defined, Z‑score normalization is applied, and K‑means clustering (elbow method) groups similar dimensional combinations to mitigate scale imbalance.
import numpy as np
from sklearn.cluster import KMeans

# Compute the clustering error (SSE) for each candidate cluster count
n = 2  # index offset: sse[0] corresponds to k=1, and diff_r shifts the index by one more
sse = []
for k in range(1, 10):
    kmeans = KMeans(n_clusters=k, random_state=0).fit(data)
    sse.append(kmeans.inertia_)
# Automatically locate the elbow point: the k after which the SSE drop flattens fastest
diff = np.diff(sse)
diff_r = diff[1:] / diff[:-1]
nice_k = np.argmin(diff_r) + n
Model Definition & Training – The LSTM network receives the embedded features, processes a 7‑day sliding window, and predicts the 8th day. Tensors are permuted into the (seq_len, batch, feature) shape the LSTM expects, and early‑stop logic guards against overfitting.
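The 7‑day sliding window described above can be sketched as follows, using a toy daily series in place of real inventory counts:

```python
import numpy as np

def make_windows(series, window=7):
    """Split a daily series into (7-day input, 8th-day target) pairs."""
    xs, ys = [], []
    for i in range(len(series) - window):
        xs.append(series[i:i + window])
        ys.append(series[i + window])
    return np.array(xs), np.array(ys)

series = np.arange(10, 20)      # ten days of toy inventory counts
X, y = make_windows(series)
print(X.shape, y.shape)         # (3, 7) (3,)
print(y)                        # [17 18 19]
```

Each training sample is therefore one week of history paired with the next day's observed inventory, which is the supervision signal the LSTM below is trained on.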
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, emb_dims, out_dim, hidden_dim, mid_layers):
        super(LSTM, self).__init__()
        # One embedding layer per categorical feature: (cardinality, embedding_dim)
        self.emb_layers = nn.ModuleList([nn.Embedding(x, y) for x, y in emb_dims])
        # LSTM input size = sum of embedding output dims + 1 continuous feature
        emb_out_dim = sum(y for _, y in emb_dims)
        self.rnn = nn.LSTM(emb_out_dim + 1, hidden_dim, mid_layers, batch_first=False)
        self.reg = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, cat_data, cont_data):
        # Embed each categorical column with its matching embedding layer
        var_x = [emb(cat_data[:, :, i]) for i, emb in enumerate(self.emb_layers)]
        x = torch.cat(var_x + [cont_data], dim=2)   # (batch, seq, feature)
        x = x.permute(1, 0, 2)                      # (seq, batch, feature) for batch_first=False
        y, _ = self.rnn(x)
        return self.reg(y[-1])                      # predict from the last time step
The early‑stop implementation monitors validation loss and stops training when improvement stalls.
class EarlyStopping:
    def __init__(self, patience=7):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, val_loss):
        score = val_loss
        if self.best_score is None:
            self.best_score = score
        elif score > self.best_score:
            # Validation loss did not improve
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
Offline & Online Incremental Training – Offline training on massive historical data produces a base model stored in Redis. Online scripts fetch yesterday's samples via Spark, perform a single incremental training pass, and overwrite the model in Redis.
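One online incremental update might look like the sketch below; it assumes `net` and `optimizer` have been restored from the Redis checkpoint and that `(x, y)` is yesterday's mini‑batch, with MSE loss standing in for whatever criterion production uses:

```python
import torch
import torch.nn as nn

def incremental_step(net, optimizer, x, y):
    """One gradient step on fresh samples, reusing the restored optimizer state."""
    criterion = nn.MSELoss()
    net.train()
    optimizer.zero_grad()
    loss = criterion(net(x), y)
    loss.backward()       # single backward pass over yesterday's batch
    optimizer.step()
    return loss.item()
```

After the step, the updated `state_dict` is re‑serialized and written back to Redis exactly as in the save/store snippets below, so the next day's script always starts from the freshest weights.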
# Save the model checkpoint (weights and optimizer state)
state = {'net': net.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(state, netPath)

# Store the serialized model in Redis
def saveModel(modelName, netPath):
    with open(netPath, "rb") as f:
        pth_model = f.read()
    result = redis.set(modelName, pth_model)

# Load the model back from Redis
def readModel(modelName, path):
    pthByte = redis.get(modelName)
    with open(path, 'wb') as f:
        f.write(pthByte)
    return torch.load(path, map_location=lambda storage, loc: storage)
Evaluation – Weighted Mean Absolute Percentage Error (WMAPE) from TorchMetrics is used; experiments show that error decreases as the number of training batches grows.
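WMAPE weights each error by the actual volume, so large inventory dimensions dominate the score: WMAPE = Σ|yᵢ − ŷᵢ| / Σ|yᵢ|. A quick hand check with hypothetical numbers, written in plain NumPy to match the TorchMetrics definition:

```python
import numpy as np

def wmape(y_true, y_pred):
    # Sum of absolute errors divided by sum of absolute actuals
    return np.abs(y_true - y_pred).sum() / np.abs(y_true).sum()

y_true = np.array([100.0, 200.0, 700.0])
y_pred = np.array([110.0, 180.0, 700.0])
print(wmape(y_true, y_pred))   # 0.03 → (10 + 20 + 0) / 1000
```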
import torchmetrics

wmape_metric = torchmetrics.WeightedMeanAbsolutePercentageError()
predict = net(valid_x)
error = wmape_metric(predict, valid_y)
error_list.append(error.item())
Deployment – A Flask RESTful service loads the model from Redis, receives a 7‑day feature window, and returns the forecast.
from flask import Flask, request, jsonify

app = Flask(__name__)

@app.route('/model/ad/forecast', methods=['POST'])
def forecast():
    prepareData = request.json.get('prepareData')
    # run_forecast (not shown) loads the model from Redis and runs prediction
    result = run_forecast(prepareData)
    return jsonify(result)
Conclusion – The LSTM‑embedding pipeline delivers accurate, multi‑dimensional ad‑inventory forecasts, supports frequent online fine‑tuning, and could be further enhanced by inserting CNN layers between embedding and LSTM to strengthen feature extraction.
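The CNN enhancement mentioned in the conclusion could be sketched as a 1‑D convolution over the embedded sequence before the LSTM; this is a hypothetical variant, not Ctrip's production code, and all dimensions below are illustrative:

```python
import torch
import torch.nn as nn

class ConvLSTM(nn.Module):
    """Hypothetical variant: a Conv1d feature extractor between embedding and LSTM."""
    def __init__(self, in_dim, conv_dim, hidden_dim, out_dim):
        super().__init__()
        # Conv1d expects (batch, channels, seq); padding=1 keeps the 7-step length
        self.conv = nn.Conv1d(in_dim, conv_dim, kernel_size=3, padding=1)
        self.rnn = nn.LSTM(conv_dim, hidden_dim, batch_first=True)
        self.reg = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):                 # x: (batch, seq, in_dim) embedded features
        c = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)
        y, _ = self.rnn(c)
        return self.reg(y[:, -1])         # forecast from the last time step

out = ConvLSTM(11, 16, 32, 1)(torch.randn(4, 7, 11))
print(out.shape)   # torch.Size([4, 1])
```

The convolution lets adjacent days share local patterns (e.g. weekend ramps) before the LSTM models the longer dependency, which is the intuition behind the proposed extension.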
Ctrip Technology
Official Ctrip Technology account, sharing and discussing growth.