Entire Space Multi‑Task Model (ESMM) for Post‑Click Conversion Rate Estimation
This article introduces the ESMM (Entire Space Multi‑Task Model) proposed by Alibaba, explaining how it tackles sample selection bias and data sparsity in post‑click conversion rate (CVR) prediction through shared embeddings and implicit pCVR learning, and provides a detailed implementation using the EasyRec framework with code examples.
The article, authored by Gao Yue from Alibaba and sourced from Zhihu, reviews the SIGIR 2018 paper "Entire Space Multi‑Task Model: An Effective Approach for Estimating Post‑Click Conversion Rate". It presents the ESMM model, which addresses two key challenges in CVR estimation: Sample Selection Bias (SSB) and Data Sparsity (DS).
In industrial recommendation systems, tasks extend beyond single‑goal CTR prediction to downstream actions such as comments, favorites, add‑to‑cart, purchases, and watch time. The paper formulates CVR as the probability of conversion given a click, and distinguishes it from CTCVR (conversion given exposure). Traditional CVR models suffer from bias because training data (clicked samples) differ from the full sample space used at inference.
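These three quantities are linked by a simple identity over the impression space, pCTCVR = pCTR × pCVR, which a tiny numeric sketch (hypothetical counts, not from the paper) makes concrete:

```python
# Hypothetical counts over one exposure log (illustrative only).
impressions = 10000
clicks = 400         # clicked impressions
conversions = 20     # conversions, necessarily a subset of clicks

p_ctr = clicks / impressions          # P(click | impression)
p_cvr = conversions / clicks          # P(conversion | click)
p_ctcvr = conversions / impressions   # P(click & conversion | impression)

# The entire-space identity that ESMM exploits: pCTCVR = pCTR * pCVR.
assert abs(p_ctcvr - p_ctr * p_cvr) < 1e-12
```

A conventional CVR model estimates p_cvr from the 400 clicked samples only, yet must score all 10,000 impressions at serving time; that train/serve mismatch is the sample selection bias ESMM targets.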
ESMM solves these problems by employing Multi‑Task Learning (MTL) with two tasks: CTR and CTCVR. It shares the same feature embeddings for both tasks and learns an implicit pCVR variable without direct supervision. The model jointly optimizes the two tasks on the entire sample space, allowing CVR to be inferred implicitly from the learned CTR and CTCVR outputs.
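The shared-embedding, two-tower wiring can be sketched in a few lines of NumPy (hypothetical shapes and random weights; this illustrates the structure, not the EasyRec implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def mlp(x, w1, w2):
    # One ReLU hidden layer, then a single logit per sample.
    return np.maximum(x @ w1, 0.0) @ w2

# Shared embedding lookup: both towers consume the same feature embeddings.
n_items, emb_dim, hidden = 1000, 8, 16
embedding_table = rng.normal(size=(n_items, emb_dim))
item_ids = np.array([3, 42, 7])
shared_emb = embedding_table[item_ids]          # shape (batch, emb_dim)

# Two towers with separate parameters on top of the shared embeddings.
w1_ctr, w2_ctr = rng.normal(size=(emb_dim, hidden)), rng.normal(size=(hidden, 1))
w1_cvr, w2_cvr = rng.normal(size=(emb_dim, hidden)), rng.normal(size=(hidden, 1))

p_ctr = sigmoid(mlp(shared_emb, w1_ctr, w2_ctr))  # supervised by click labels
p_cvr = sigmoid(mlp(shared_emb, w1_cvr, w2_cvr))  # implicit: never labeled directly
p_ctcvr = p_ctr * p_cvr                           # supervised by click & convert labels

# pCVR is a valid probability by construction, and pCTCVR can never exceed pCTR.
assert np.all((p_cvr >= 0) & (p_cvr <= 1))
assert np.all(p_ctcvr <= p_ctr)
```

Because only p_ctr and p_ctcvr receive supervision, the CVR tower is trained entirely through the product, while the shared embeddings let the data-rich CTR task regularize the sparse conversion signal.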
The objective function combines the losses of the CTR and CTCVR tasks, effectively learning CVR without requiring explicit labels for unclicked samples.
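Concretely, the objective is L = CE(click, pCTR) + CE(click · conversion, pCTR · pCVR), with both cross-entropy sums taken over all impressions. A minimal sketch with made-up labels and predictions (not tied to EasyRec):

```python
import numpy as np

def bce(y, p, eps=1e-7):
    # Element-wise binary cross entropy with clipping for numerical safety.
    p = np.clip(p, eps, 1 - eps)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

# Entire-space labels: y = click, z = conversion (z can be 1 only when y is 1).
y = np.array([1, 1, 0, 0, 1], dtype=float)   # click labels
z = np.array([1, 0, 0, 0, 0], dtype=float)   # conversion labels

# Model outputs (illustrative values).
p_ctr = np.array([0.8, 0.6, 0.1, 0.2, 0.5])
p_cvr = np.array([0.7, 0.2, 0.3, 0.1, 0.4])  # implicit head: no loss of its own

# ESMM objective: CTR loss on y plus CTCVR loss on y*z, both over all samples.
loss = bce(y, p_ctr).sum() + bce(y * z, p_ctr * p_cvr).sum()
```

Since the CTCVR term multiplies pCTR and pCVR before the loss, gradients reach the CVR tower from every impression, clicked or not, even though pCVR itself never needs an explicit label.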
Implementation details are provided using Alibaba's EasyRec recommendation framework. The following code snippets illustrate the core components:
```python
def build_predict_graph(self):
  """Forward function.

  Returns:
    self._prediction_dict: Prediction result of two tasks.
  """
  # ... (omitted tensor generation logic)

  # CVR tower: a DNN over the shared features, then a single logit.
  cvr_tower_name = self._cvr_tower_cfg.tower_name
  dnn_model = dnn.DNN(self._cvr_tower_cfg.dnn, self._l2_reg,
                      name=cvr_tower_name, is_training=self._is_training)
  cvr_tower_output = dnn_model(all_fea)
  cvr_tower_output = tf.layers.dense(
      inputs=cvr_tower_output, units=1, kernel_regularizer=self._l2_reg,
      name='%s/dnn_output' % cvr_tower_name)

  # CTR tower: a separate DNN over the same shared features.
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  dnn_model = dnn.DNN(self._ctr_tower_cfg.dnn, self._l2_reg,
                      name=ctr_tower_name, is_training=self._is_training)
  ctr_tower_output = dnn_model(all_fea)
  ctr_tower_output = tf.layers.dense(
      inputs=ctr_tower_output, units=1, kernel_regularizer=self._l2_reg,
      name='%s/dnn_output' % ctr_tower_name)

  tower_outputs = {
      cvr_tower_name: cvr_tower_output,
      ctr_tower_name: ctr_tower_output
  }
  self._add_to_prediction_dict(tower_outputs)
  return self._prediction_dict
```

```python
def build_loss_graph(self):
  """Build loss graph.

  Returns:
    self._loss_dict: Weighted loss of ctr and cvr.
  """
  cvr_tower_name = self._cvr_tower_cfg.tower_name
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  cvr_label_name = self._label_name_dict[cvr_tower_name]
  ctr_label_name = self._label_name_dict[ctr_tower_name]

  # CTCVR label: click AND conversion, defined over the entire sample space.
  ctcvr_label = tf.cast(
      self._labels[cvr_label_name] * self._labels[ctr_label_name], tf.float32)
  cvr_loss = tf.keras.backend.binary_crossentropy(
      ctcvr_label, self._prediction_dict['probs_ctcvr'])
  cvr_loss = tf.reduce_sum(cvr_loss, name='ctcvr_loss')
  self._loss_dict['weighted_cross_entropy_loss_%s' % cvr_tower_name] = \
      self._cvr_tower_cfg.weight * cvr_loss

  # CTR loss: sigmoid cross entropy on the click labels.
  ctr_loss = tf.reduce_sum(
      tf.nn.sigmoid_cross_entropy_with_logits(
          labels=tf.cast(self._labels[ctr_label_name], tf.float32),
          logits=self._prediction_dict['logits_%s' % ctr_tower_name]),
      name='ctr_loss')
  self._loss_dict['weighted_cross_entropy_loss_%s' % ctr_tower_name] = \
      self._ctr_tower_cfg.weight * ctr_loss
  return self._loss_dict
```

```python
def build_metric_graph(self, eval_config):
  """Build metric graph.

  Args:
    eval_config: Evaluation configuration.

  Returns:
    metric_dict: AUC of ctr, cvr and ctcvr.
  """
  metric_dict = {}
  cvr_tower_name = self._cvr_tower_cfg.tower_name
  ctr_tower_name = self._ctr_tower_cfg.tower_name
  cvr_label_name = self._label_name_dict[cvr_tower_name]
  ctr_label_name = self._label_name_dict[ctr_tower_name]
  for metric in self._cvr_tower_cfg.metrics_set:
    # CTCVR metrics are computed over the entire space.
    ctcvr_label_name = cvr_label_name + '_ctcvr'
    cvr_dtype = self._labels[cvr_label_name].dtype
    self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
        self._labels[ctr_label_name], cvr_dtype)
    metric_dict.update(self._build_metric_impl(
        metric, loss_type=self._cvr_tower_cfg.loss_type,
        label_name=ctcvr_label_name,
        num_class=self._cvr_tower_cfg.num_class, suffix='_ctcvr'))

    # CVR metrics are evaluated only on clicked samples (ctr label > 0).
    cvr_label_masked_name = cvr_label_name + '_masked'
    ctr_mask = self._labels[ctr_label_name] > 0
    self._labels[cvr_label_masked_name] = tf.boolean_mask(
        self._labels[cvr_label_name], ctr_mask)
    pred_prefix = ('probs' if self._cvr_tower_cfg.loss_type ==
                   LossType.CLASSIFICATION else 'y')
    pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
    self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
        self._prediction_dict[pred_name], ctr_mask)
    metric_dict.update(self._build_metric_impl(
        metric, loss_type=self._cvr_tower_cfg.loss_type,
        label_name=cvr_label_masked_name,
        num_class=self._cvr_tower_cfg.num_class,
        suffix='_%s_masked' % cvr_tower_name))
  for metric in self._ctr_tower_cfg.metrics_set:
    metric_dict.update(self._build_metric_impl(
        metric, loss_type=self._ctr_tower_cfg.loss_type,
        label_name=ctr_label_name,
        num_class=self._ctr_tower_cfg.num_class,
        suffix='_%s' % ctr_tower_name))
  return metric_dict
```

The article also notes practical observations: experiments on the public AliCCP dataset show a pronounced "seesaw" effect, where improving CTR often harms CVR and vice versa. Links to the EasyRec repository and the ESMM implementation are provided.
References include the original SIGIR 2018 paper, the EasyRec documentation, and additional resources on multi‑task learning models such as MMoE, PLE, and DBMTL.
DataFunTalk
Dedicated to sharing and discussing big data and AI technology applications, aiming to empower a million data scientists. Regularly hosts live tech talks and curates articles on big data, recommendation/search algorithms, advertising algorithms, NLP, intelligent risk control, autonomous driving, and machine learning/deep learning.