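Deep Learning Training Pipeline for Small Molecule and Protein Interactions
Validating the feasibility of a model training recipe at low cost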
Preface
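This article describes a deep learning training recipe for predicting small molecule and protein binding, developed for the NeurIPS BELKA competition and using the BELKA dataset. Among the publicly shared solutions, it has been shown to be the best practice under the competition's time limits and execution requirements.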
Model architecture and tokenization
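The architecture is simple and fairly shallow: only 4 encoder layers with 8 heads each. Because the vocabulary contains only 43 tokens, the model dimension is fixed at 32. I also tried 64 and 16, but both performed worse.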
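To make "simple" concrete, the sketch below counts trainable parameters for the configuration used in the code (depth 32, 4 layers, 8 heads, key_dim 32, feed-forward width 2 × depth, vocabulary of 43). It assumes Keras defaults for MultiHeadAttention (biases enabled, value_dim equal to key_dim) and fixed, non-learned positional encodings as in the Encodings layer; the helper name count_params is hypothetical.

```python
def count_params(depth: int = 32, num_heads: int = 8, key_dim: int = 32,
                 num_layers: int = 4, vocab_size: int = 43, head_units: int = 3) -> dict:
    """Hypothetical helper: trainable parameters of the encoder described above."""
    # Multi-head attention: Q, K and V projections (depth -> num_heads * key_dim, with bias)
    qkv = 3 * (depth * num_heads * key_dim + num_heads * key_dim)
    # Output projection back to depth (with bias)
    out = num_heads * key_dim * depth + depth
    attention = qkv + out
    # Two LayerNormalization layers per encoder layer (gamma + beta each)
    norms = 2 * (2 * depth)
    # Feed-forward: depth -> 2 * depth -> depth
    ffn = (depth * 2 * depth + 2 * depth) + (2 * depth * depth + depth)
    per_layer = attention + norms + ffn
    embedding = vocab_size * depth              # token embeddings; positional encodings are fixed
    head = depth * head_units + head_units      # Dense head after global average pooling
    total = num_layers * per_layer + embedding + head
    return {'per_layer': per_layer, 'embedding': embedding, 'head': head, 'total': total}

print(count_params())                 # {'per_layer': 37888, 'embedding': 1376, 'head': 99, 'total': 153027}
print(count_params(head_units=2048))  # 'fps' head with 2048 ECFP bits: total 220512
```

Under these assumptions the classification model has roughly 153k parameters, most of them in the attention projections: each of the 8 heads projects the 32-dimensional input to 32 dimensions, so attention works in a 256-dimensional space inside a 32-dimensional model.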
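I probably used atomInSmiles in a not entirely correct way, which produced an unusual tokenization scheme: each token is either an atom (e.g. C, H, S), a digit, or the contents of a pair of square brackets (e.g. [C@@]). Chemists are better placed to interpret what these tokens represent.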
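The token types listed above (atoms, ring-closure digits, bracketed atoms such as [C@@]) can be approximated with the widely used SMILES regular expression from the Molecular Transformer work. This is a substitute for atomInSmiles, not what the code below runs, and the names tokenize and build_vocab are hypothetical. It also illustrates why vocab_size is set to N + 2: id 0 is reserved for padding and id 1 for [MASK] (in the code, TextVectorization's out-of-vocabulary slot is reused for this).

```python
import re

# Regex adapted from the Molecular Transformer SMILES tokenizer (an assumption,
# not the atomInSmiles behaviour described in the text).
SMILES_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
)

def tokenize(smiles: str) -> list[str]:
    tokens = SMILES_REGEX.findall(smiles)
    # Fail loudly if some characters were not covered by the regex
    assert ''.join(tokens) == smiles, f'untokenizable characters in {smiles!r}'
    return tokens

def build_vocab(smiles_list: list[str]) -> dict[str, int]:
    """Map tokens to ids, reserving 0 for [PAD] and 1 for [MASK]."""
    tokens = sorted({tok for s in smiles_list for tok in tokenize(s)})
    return {tok: i + 2 for i, tok in enumerate(tokens)}

corpus = ['CC(=O)Nc1ccc(Cl)cc1', 'C[C@@H](N)C(=O)O']
vocab = build_vocab(corpus)
print(tokenize(corpus[1]))  # ['C', '[C@@H]', '(', 'N', ')', 'C', '(', '=', 'O', ')', 'O']
print(len(vocab) + 2)       # vocab_size = N + 2 -> 12
```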
Pre-training
The model is pre-trained from scratch in two stages:
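- MLM: standard masked-token prediction (15% of tokens are selected; of these, 80% are replaced with [MASK], 10% are replaced with a random token and 10% are left unchanged, the same scheme as in "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"). I used a categorical focal cross-entropy with dynamic per-mini-batch alpha weights, but only as an experiment, and I am not sure whether it contributed to the result. Training ran for about 100 epochs of 10,000 steps each with 2048 samples per batch, so the model saw the data roughly 20 times. The training data included everything: the training set, the test set and the external data (see the reference below).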
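- SMILES to ECFP (size 2048, chirality included). Same model, but with a different output layer (a Dense layer with sigmoid activation) and frozen embeddings. Trained for 20-50 epochs. The model performs only moderately (several studies report that SMILES encoders struggle to predict topological fingerprints, and this model is no exception), reaching a MAP of about 0.4, but it does learn some useful representations along the way.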
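ECFP was chosen for two reasons: (a) speed: it is fast to compute, especially with the scikit-fingerprints library; (b) the predicted fingerprint bits have no predefined meaning (unlike MACCS or PubChem keys), which makes the task harder for the SMILES transformer and may therefore lead to better feature representations.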
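The BERT-style masking used in the first stage can be sketched in NumPy as follows. It mirrors the logic of get_model_inputs in the code below (padding id 0, [MASK] id 1, random replacements drawn from ids 2 to vocab_size - 1, and a target of -1 for positions that are not predicted), but the function name mask_tokens and the NumPy formulation are illustrative assumptions.

```python
import numpy as np

def mask_tokens(tokens: np.ndarray, vocab_size: int, masking_rate: float = 0.15,
                mask_id: int = 1, seed: int = 0):
    """Return (inputs, targets) for masked language modelling.

    tokens: int array [batch, length], 0 = [PAD].
    targets: original id at selected positions, -1 elsewhere (ignored by the loss).
    """
    rng = np.random.default_rng(seed)
    not_pad = tokens != 0
    # One of four actions per token: 0 = not selected, 1 = [MASK], 2 = random token, 3 = unchanged
    probs = [1.0 - masking_rate, 0.8 * masking_rate, 0.1 * masking_rate, 0.1 * masking_rate]
    action = rng.choice(4, size=tokens.shape, p=probs)
    action = np.where(not_pad, action, 0)  # never select padding
    inputs = tokens.copy()
    inputs[action == 1] = mask_id
    random_ids = rng.integers(2, vocab_size, size=tokens.shape)  # upper bound is exclusive
    inputs[action == 2] = random_ids[action == 2]
    selected = action > 0  # actions 1, 2 and 3 are all predicted
    targets = np.where(selected, tokens, -1)
    return inputs, targets

tokens = np.array([[5, 7, 9, 11, 13, 0, 0, 0]])
inputs, targets = mask_tokens(tokens, vocab_size=43, masking_rate=0.5, seed=1)
print(inputs)
print(targets)
```

In the TensorFlow version the same four-way choice is drawn with tf.random.categorical over log-probabilities, which keeps the whole step inside the tf.data pipeline.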
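Training: the BELKA training set combined with the external data. Because the external data only has labels for the sEH protein, the loss and metrics are masked.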
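Here "masked" means that a label value of 2 (the nan_mask that make_parquet assigns to the missing BRD4 and HSA labels of the external data) contributes nothing to the loss or to its denominator. A minimal NumPy sketch of a masked binary cross-entropy, leaving out the focal term and class weights of MultiLabelLoss, assuming labels shaped [batch, 3] in the order BRD4, HSA, sEH; masked_bce is a hypothetical name:

```python
import numpy as np

def masked_bce(y_true: np.ndarray, y_pred: np.ndarray, nan_mask: int = 2, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over labelled entries only (label == nan_mask is ignored)."""
    mask = (y_true != nan_mask).astype(float)
    y = np.where(mask > 0, y_true, 0).astype(float)
    p = np.clip(y_pred, eps, 1.0 - eps)
    loss = -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)) * mask
    # Assumes at least one labelled entry in the batch
    return float(loss.sum() / mask.sum())

# One BELKA row (all three labels) and one external-data row (sEH only)
y_true = np.array([[0, 1, 0],
                   [2, 2, 1]])
y_pred = np.array([[0.1, 0.8, 0.3],
                   [0.9, 0.9, 0.6]])
print(masked_bce(y_true, y_pred))  # the two 0.9 predictions in row 2 are ignored
```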
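Validation: 3% of the building blocks are held out from the training set, so that every validation molecule contains at least one building block that never appears in training; about 9 million samples in total.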
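The building-block split can be sketched in plain Python: sample 3% of all distinct building blocks, then send any molecule that uses at least one of them to validation. This follows validation_split in make_parquet; the toy molecules, the name block_split and the use of random.sample instead of train_test_split are assumptions for illustration.

```python
import random

def block_split(molecules: list[tuple[str, str, str]], test_size: float = 0.03, seed: int = 42) -> list[int]:
    """Return 0 (train) or 1 (validation) for each (block1, block2, block3) triple."""
    blocks = sorted({b for triple in molecules for b in triple})  # sorted for reproducibility
    k = max(1, round(test_size * len(blocks)))
    held_out = set(random.Random(seed).sample(blocks, k))
    # A molecule goes to validation if any of its blocks is held out,
    # so no validation molecule is built only from training blocks
    return [1 if held_out.intersection(triple) else 0 for triple in molecules]

molecules = [(f'A{i % 10}', f'B{i % 7}', f'C{i % 5}') for i in range(100)]
flags = block_split(molecules, test_size=0.1)
print(sum(flags), 'of', len(flags), 'molecules go to validation')
```

Splitting by building block rather than by molecule mimics the competition test set, where many molecules contain blocks never seen in training.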
Hardware: A100 GPU.
Summary: as mentioned, there is nothing special about this model; the result came from a combination of luck and randomness.
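The winning model is a very basic encoder: self-attention layer -> feed-forward layer, with 4 layers, 8 heads per layer and a key/value dimension of 32 per head. It is largely taken from the Transformer chapter of the TensorFlow tutorials.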
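The positional encoding in the Encodings layer follows that tutorial: sines for even feature indices and cosines for odd ones, concatenated as two halves rather than interleaved (equivalent up to a fixed permutation of features). A NumPy sketch of _pos_encodings, together with the sqrt(depth) scaling applied to token embeddings before the encodings are added; the function names are hypothetical:

```python
import numpy as np

def positional_encodings(max_length: int, depth: int) -> np.ndarray:
    """Sinusoidal encodings of shape [max_length, depth]: sine half, then cosine half."""
    positions = np.arange(max_length, dtype=np.float64)[:, None]
    idx = np.arange(depth)[None, :]
    angles = 1.0 / np.power(10000.0, (2 * (idx // 2)) / depth)
    radians = positions * angles
    return np.concatenate([np.sin(radians[:, 0::2]), np.cos(radians[:, 1::2])], axis=-1)

def encode(embeddings: np.ndarray) -> np.ndarray:
    """Scale token embeddings [batch, length, depth] by sqrt(depth) and add positions."""
    _, length, depth = embeddings.shape
    return embeddings * np.sqrt(depth) + positional_encodings(length, depth)[None]

pe = positional_encodings(max_length=128, depth=32)
print(pe.shape)                  # (128, 32)
print(pe[0, :4], pe[0, 16:20])   # position 0: sines are 0, cosines are 1
```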
I used the atomInSmiles tokenizer but applied it incorrectly, so the tokens are almost character-level. I did not use any pre-trained model (such as ChemBERTa).
What probably made it work is the two-stage pre-training:
- MLM with a 15% masking rate
- SMILES-to-ECFP prediction
My guess is that the second stage is what lets the encoder "learn" meaningful information from SMILES.
The external data is the preprocessed "Building Block-Based Binding Predictions for DNA-Encoded Libraries" dataset (adding highly relevant data to training improves final performance).
Possible improvements:
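- More elaborate tokenization schemes: bigram and trigram tokens, atomInSmiles
- Models with a depth above 32 and more than 6 encoder layers
- Multi-input models (SMILES + fingerprints)
- Pre-training on a larger dataset (I experimented with ZINC for about a month)
- Custom loss functions: a binary focal loss worked well
- Gated fusion of building blocks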
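For reference, the binary focal loss with dynamic mini-batch class weights that MultiLabelLoss implements in macro mode can be restated in NumPy: for each label, alpha = 1/sqrt(count) for every class present in the batch, rescaled so that the weights of the labelled samples sum to their number; entries equal to nan_mask (2) get zero weight. As in the TensorFlow code, the weighted loss is averaged across labels per sample and then over the batch. This is a reading aid under those assumptions, not a drop-in replacement, and focal_loss_dynamic_alpha is a hypothetical name.

```python
import numpy as np

def focal_loss_dynamic_alpha(y_true: np.ndarray, y_pred: np.ndarray, gamma: float = 2.0,
                             nan_mask: int = 2, eps: float = 1e-7) -> float:
    """NumPy restatement of MultiLabelLoss(macro=True) for arrays of shape [batch, labels]."""
    y_true = y_true.astype(int)
    mask = (y_true != nan_mask).astype(float)
    # Per-label counts of classes 0 and 1 (nan_mask entries are not counted): [labels, 2]
    freq = np.stack([(y_true == 0).sum(axis=0), (y_true == 1).sum(axis=0)], axis=-1).astype(float)
    alpha = np.zeros_like(freq)
    np.divide(1.0, np.sqrt(freq), out=alpha, where=freq > 0)  # 1/sqrt(count), 0 for absent classes
    # Rescale so that, per label, the weights of labelled samples sum to their count
    # (a label with no labelled samples in the batch would give NaN, as in the TF code)
    alpha *= freq.sum(axis=1, keepdims=True) / (alpha * freq).sum(axis=1, keepdims=True)
    # Per-entry weight: alpha of the entry's class, 0 for masked entries
    w = np.where(y_true == 1, alpha[:, 1], np.where(y_true == 0, alpha[:, 0], 0.0))
    y = y_true * mask
    p = np.clip(y_pred, eps, 1.0 - eps)
    pt = y * p + (1.0 - y) * (1.0 - p)
    loss = -w * (1.0 - pt) ** gamma * np.log(pt) * mask
    # Weighted average across labels for each sample, then mean over the batch
    return float(np.mean(loss.sum(axis=1) / (w * mask).sum(axis=1)))

# Three BELKA rows and one external-data row (sEH label only)
y_true = np.array([[0, 0, 1], [0, 1, 0], [0, 0, 0], [2, 2, 1]])
y_pred = np.array([[0.2, 0.1, 0.7], [0.1, 0.6, 0.2], [0.3, 0.2, 0.1], [0.5, 0.5, 0.8]])
print(focal_loss_dynamic_alpha(y_true, y_pred))
```

The 1/sqrt(count) weights up-weight the rare positive class without the instability of full inverse-frequency weights, while the focal term further reduces the contribution of easy, confidently predicted negatives.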
Full code
```python
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense, LayerNormalization, Add, MultiHeadAttention, Dropout
from keras.layers import Embedding, TextVectorization, GlobalAvgPool1D
from keras import Sequential
from sklearn.model_selection import train_test_split
import os
from typing import Union
import mapply
from skfp.fingerprints import ECFPFingerprint
from rdkit import Chem
import atomInSmiles
import dask.dataframe as dd
import pyarrow as pa
import itertools
import einops


# LOSSES and METRICS >>>
@keras.saving.register_keras_serializable(package='belka', name='MultiLabelLoss')
class MultiLabelLoss(keras.losses.Loss):
    """
    * Macro- or Micro-averaged Weighted Masked Binary Focal loss.
    * Dynamic mini-batch class weights "alpha".
    * Used for binary multilabel classification.
    """
    def __init__(self, epsilon: float, macro: bool, gamma: float = 2.0, nan_mask: int = 2, name='loss', **kwargs):
        super(MultiLabelLoss, self).__init__(name=name)
        self.epsilon = epsilon
        self.gamma = gamma
        self.macro = macro
        self.nan_mask = nan_mask

    def call(self, y_true, y_pred):
        # Cast y_true to tf.int32
        y_true = tf.cast(y_true, dtype=tf.int32)
        # Compute class weights ("alpha"): Inverse of Square Root of Number of Samples
        # Compute "alpha" for each label if "macro" = True
        # Assign zero class weights to missing classes
        # Normalize: sum of sample weights = sample count per label
        # (maxlength=2 counts only classes 0 and 1, so nan_mask entries (2) are excluded)
        freq = tf.math.bincount(
            arr=tf.transpose(y_true, perm=[1, 0]) if self.macro else y_true,
            minlength=2, maxlength=2, dtype=tf.float32, axis=-1 if self.macro else 0)
        alpha = tf.where(condition=tf.equal(freq, 0.0), x=0.0, y=tf.math.rsqrt(freq))
        ax = 1 if self.macro else None
        alpha = alpha * tf.reduce_sum(freq, axis=ax, keepdims=True) / tf.reduce_sum(alpha * freq, axis=ax, keepdims=True)
        alpha = tf.reduce_sum(alpha * tf.one_hot(y_true, axis=-1, depth=2, dtype=tf.float32), axis=-1)
        # Mask and set to zero missing labels
        y_true = tf.cast(y_true, tf.float32)
        mask = tf.cast(tf.not_equal(y_true, tf.constant(self.nan_mask, tf.float32)), dtype=tf.float32)
        y_true = y_true * mask
        # Compute loss
        y_pred = tf.clip_by_value(y_pred, clip_value_min=self.epsilon, clip_value_max=1.0 - self.epsilon)
        pt = tf.add(
            tf.multiply(y_true, y_pred),
            tf.multiply(tf.subtract(1.0, y_true), tf.subtract(1.0, y_pred)))
        loss = - alpha * (1.0 - pt) ** self.gamma * tf.math.log(pt) * mask
        # Macro: weighted average across labels per sample, then batch mean; micro: over all entries
        loss = tf.divide(tf.reduce_sum(loss, axis=ax), tf.reduce_sum(alpha * mask, axis=ax))
        loss = tf.reduce_mean(loss)
        return loss

    def get_config(self):
        config = super(MultiLabelLoss, self).get_config()
        config.update({
            'epsilon': self.epsilon,
            'gamma': self.gamma,
            'macro': self.macro,
            'nan_mask': self.nan_mask})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='CategoricalLoss')
class CategoricalLoss(keras.losses.Loss):
    """
    Masked Categorical Focal loss.
    Dynamic mini-batch class weights ("alpha").
    Used for MLM training.
    """
    def __init__(self, epsilon: float, mask: int, vocab_size: int, gamma: float = 2.0, name='loss', **kwargs):
        super(CategoricalLoss, self).__init__(name=name)
        self.epsilon = epsilon
        self.gamma = gamma
        self.mask = mask
        self.vocab_size = vocab_size

    def call(self, y_true, y_pred):
        # Unpack y_true to masked (y_true) and unmasked (unmasked) arrays
        unmasked = y_true[:, :, 1]
        y_true = y_true[:, :, 0]
        # Reshape inputs
        y_true = einops.rearrange(y_true, 'b l -> (b l)')
        y_pred = einops.rearrange(y_pred, 'b l c -> (b l) c')
        # Drop positions that are not predicted (target == mask value) from y_true
        mask = tf.not_equal(y_true, self.mask)
        y_true = tf.boolean_mask(y_true, mask)
        # Compute class weights ("alpha"): Inverse of Square Root of Number of Samples
        # Assign zero class weights to missing classes
        # Normalize: sum of sample weights = sample count
        freq = tf.math.bincount(unmasked, minlength=self.vocab_size, dtype=tf.float32)
        freq = tf.concat([tf.zeros(shape=(2,)), freq[2:]], axis=0)  # Set frequencies for [PAD], [MASK] = 0
        alpha = tf.where(condition=tf.equal(freq, 0.0), x=0.0, y=tf.math.rsqrt(freq))
        # Convert y_true to one-hot
        # Apply mask to y_pred
        y_true = tf.one_hot(y_true, depth=self.vocab_size, axis=-1, dtype=tf.float32)
        y_pred = tf.boolean_mask(y_pred, mask, axis=0)
        # Compute loss
        y_pred = tf.clip_by_value(y_pred, clip_value_min=self.epsilon, clip_value_max=1.0 - self.epsilon)
        pt = tf.add(
            tf.multiply(y_true, y_pred),
            tf.multiply(tf.subtract(1.0, y_true), tf.subtract(1.0, y_pred)))
        loss = - alpha * ((1.0 - pt) ** self.gamma) * (y_true * tf.math.log(y_pred))
        loss = tf.divide(tf.reduce_sum(loss), tf.reduce_sum(alpha * y_true))
        return loss

    def get_config(self) -> dict:
        config = super(CategoricalLoss, self).get_config()
        config.update({
            'epsilon': self.epsilon,
            'gamma': self.gamma,
            'mask': self.mask,
            'vocab_size': self.vocab_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='BinaryLoss')
class BinaryLoss(keras.losses.Loss):
    """
    Binary Focal loss.
    Used for FPs training.
    """
    def __init__(self, name='loss', **kwargs):
        super(BinaryLoss, self).__init__(name=name)
        self.loss = tf.keras.losses.BinaryFocalCrossentropy()

    def call(self, y_true, y_pred):
        # Flatten all fingerprint bits into one column of binary targets
        y_true = tf.cast(y_true, dtype=tf.float32)
        y_true = tf.reshape(y_true, shape=(-1, 1))
        y_pred = tf.reshape(y_pred, shape=(-1, 1))
        loss = self.loss(y_true, y_pred)
        return loss

    def get_config(self) -> dict:
        config = super(BinaryLoss, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='MaskedAUC')
class MaskedAUC(keras.metrics.AUC):
    def __init__(self, mode: str, mask: int, multi_label: bool, num_labels: Union[int, None], vocab_size: int,
                 name='auc', **kwargs):
        super(MaskedAUC, self).__init__(curve='PR', multi_label=multi_label, num_labels=num_labels, name=name)
        self.mode = mode
        self.multi_label = multi_label
        self.mask = mask
        self.num_labels = num_labels
        self.vocab_size = vocab_size

    def update_state(self, y_true, y_pred, sample_weight=None):
        if self.mode == 'mlm':
            # Unpack y_true to masked (y_true) and unmasked (unmasked) arrays
            unmasked = y_true[:, :, 1]
            y_true = y_true[:, :, 0]
            # Reshape inputs
            y_true = einops.rearrange(y_true, 'b l -> (b l)')
            y_pred = einops.rearrange(y_pred, 'b l c -> (b l) c')
            # Drop non-masked tokens from y_true
            mask = tf.not_equal(y_true, self.mask)
            y_true = tf.boolean_mask(y_true, mask)
            # Convert y_true to one-hot
            # Apply mask to y_pred
            y_true = tf.one_hot(y_true, depth=self.vocab_size, axis=-1, dtype=tf.float32)
            y_pred = tf.boolean_mask(y_pred, mask, axis=0)
            mask = None
        elif self.mode == 'clf':
            # Missing labels (nan_mask) get zero sample weight
            mask = tf.cast(tf.not_equal(y_true, self.mask), dtype=tf.float32)
        else:
            y_true = tf.reshape(y_true, shape=(-1, 1))
            y_pred = tf.reshape(y_pred, shape=(-1, 1))
            mask = tf.ones_like(y_pred, dtype=tf.float32)
        # Compute PR-AUC (averaged over labels when multi_label=True)
        super().update_state(y_true, y_pred, sample_weight=mask)

    def get_config(self) -> dict:
        config = super(MaskedAUC, self).get_config()
        config.update({
            'mode': self.mode,
            'multi_label': self.multi_label,
            'mask': self.mask,
            'num_labels': self.num_labels,
            'vocab_size': self.vocab_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)
```
```python
# LAYERS >>>
class FPGenerator(tf.keras.layers.Layer):
    """
    Compute 2048-bit ECFP fingerprints (chirality included) from SMILES strings.
    Used for dataset preprocessing only, not inside the model.
    """
    def __init__(self, name: str = 'fingerprints', **kwargs):
        super(FPGenerator, self).__init__(name=name)
        self.fp_size = 2048
        self.transformer = ECFPFingerprint(fp_size=self.fp_size, include_chirality=True, n_jobs=-1)

    def call(self, inputs, *args, **kwargs):
        """
        Get fingerprints given SMILES string.
        """
        x = tf.py_function(
            func=self.get_fingerprints,
            inp=[inputs],
            Tout=tf.int8)
        # py_function drops static shape information; restore it as [batch, fp_size]
        x.set_shape(inputs.shape.concatenate([self.fp_size]))
        return x

    def get_fingerprints(self, inputs):
        # String tensors arrive as bytes; decode them before passing them to scikit-fingerprints
        x = [s.decode('utf-8') for s in inputs.numpy()]
        x = self.transformer.transform(x)
        x = tf.constant(x, dtype=tf.int8)
        return x


@keras.saving.register_keras_serializable(package='belka', name='Encodings')
class Encodings(keras.layers.Layer):
    def __init__(self, depth: int, max_length: int, name: str = 'encodings', **kwargs):
        super(Encodings, self).__init__(name=name)
        self.depth = depth
        self.max_length = max_length
        self.encodings = self._pos_encodings(depth=depth, max_length=max_length)

    def call(self, inputs, training=False, *args, **kwargs):
        # Scale embeddings by sqrt(depth), then add the encodings of the first L positions
        scale = tf.ones_like(inputs) * tf.math.sqrt(tf.cast(self.depth, tf.float32))
        x = tf.multiply(inputs, scale)
        x = tf.add(x, self.encodings[tf.newaxis, :tf.shape(x)[1], :])
        return x

    @staticmethod
    def _pos_encodings(depth: int, max_length: int):
        """
        Get positional encodings of shape [max_length, depth]
        """
        positions = tf.range(max_length, dtype=tf.float32)[:, tf.newaxis]
        idx = tf.range(depth)[tf.newaxis, :]
        power = tf.cast(2 * (idx // 2), tf.float32)
        power /= tf.cast(depth, tf.float32)
        angles = 1. / tf.math.pow(10000., power)
        radians = positions * angles
        sin = tf.math.sin(radians[:, 0::2])
        cos = tf.math.cos(radians[:, 1::2])
        encodings = tf.concat([sin, cos], axis=-1)
        return encodings

    def get_config(self) -> dict:
        return {
            'depth': self.depth,
            'max_length': self.max_length,
            'name': self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='Embeddings')
class Embeddings(tf.keras.layers.Layer):
    def __init__(self, max_length: int, depth: int, input_dim: int, name: str = 'embeddings', **kwargs):
        super(Embeddings, self).__init__(name=name)
        self.depth = depth
        self.max_length = max_length
        self.input_dim = input_dim
        # mask_zero=True: token 0 ([PAD]) produces the padding mask used by the attention layers
        self.embeddings = Embedding(input_dim=input_dim, output_dim=depth, mask_zero=True)
        self.encodings = Encodings(depth=depth, max_length=max_length)

    def build(self, input_shape):
        self.embeddings.build(input_shape=input_shape)
        super().build(input_shape)

    def compute_mask(self, *args, **kwargs):
        return self.embeddings.compute_mask(*args, **kwargs)

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.embeddings(inputs)
        x = self.encodings(x)
        return x

    def get_config(self) -> dict:
        return {
            'depth': self.depth,
            'input_dim': self.input_dim,
            'max_length': self.max_length,
            'name': self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='FeedForward')
class FeedForward(tf.keras.layers.Layer):
    """
    * Feed-forward block with PRE-layer normalization
    * LayerNorm -> Dense (2 x depth) -> Dense (depth) -> Dropout -> Skip connection
    """
    def __init__(self, activation: str, depth: int, dropout_rate: float, epsilon: float, name: str = 'ffn', **kwargs):
        super(FeedForward, self).__init__(name=name)
        self.activation = activation
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        self.norm = LayerNormalization(epsilon=epsilon)
        self.dense1 = Dense(units=int(depth * 2), activation=activation)
        self.dense2 = Dense(units=depth)
        self.dropout = Dropout(rate=dropout_rate)
        self.add = Add()

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.norm(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dropout(x, training=training)
        x = self.add([x, inputs])
        return x

    def get_config(self) -> dict:
        return {
            'activation': self.activation,
            'depth': self.depth,
            'dropout_rate': self.dropout_rate,
            'epsilon': self.epsilon,
            'name': self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='SelfAttention')
class SelfAttention(tf.keras.layers.Layer):
    """
    * Self-Attention block with PRE-layer normalization
    * LayerNorm -> MHA -> Skip connection
    """
    def __init__(self, causal: bool, depth: int, dropout_rate: float, epsilon: float, max_length: int, num_heads: int,
                 name: str = 'self_attention', **kwargs):
        super(SelfAttention, self).__init__(name=name)
        self.causal = causal
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        self.max_length = max_length
        self.num_heads = num_heads
        self.supports_masking = True
        self.norm = LayerNormalization(epsilon=epsilon)
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=depth, dropout=dropout_rate)
        self.add = Add()

    def build(self, input_shape):
        self.mha.build(input_shape=[input_shape, input_shape])
        super().build(input_shape)

    def call(self, inputs, training=False, *args, **kwargs):
        # Compute attention mask [batch, L, L]: True only where both tokens are not padding
        mask = tf.cast(inputs._keras_mask, dtype=tf.float32)
        m1 = tf.expand_dims(mask, axis=2)
        m2 = tf.expand_dims(mask, axis=1)
        mask = tf.cast(tf.linalg.matmul(m1, m2), dtype=tf.bool)
        # Compute outputs
        x = self.norm(inputs)
        x = self.mha(
            query=x,
            value=x,
            use_causal_mask=self.causal,
            training=training,
            attention_mask=mask)
        x = self.add([x, inputs])
        return x

    def get_config(self) -> dict:
        return {
            'causal': self.causal,
            'depth': self.depth,
            'dropout_rate': self.dropout_rate,
            'epsilon': self.epsilon,
            'max_length': self.max_length,
            'name': self.name,
            'num_heads': self.num_heads}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='EncoderLayer')
class EncoderLayer(tf.keras.layers.Layer):
    """
    * Encoder layer with PRE-layer normalization: LayerNorm -> Self-Attention -> LayerNorm -> FeedForward.
    """
    def __init__(self, activation: str, depth: int, dropout_rate: float, epsilon: float, max_length: int,
                 num_heads: int, name: str = 'encoder_layer', **kwargs):
        super(EncoderLayer, self).__init__(name=name)
        self.activation = activation
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        self.max_length = max_length
        self.num_heads = num_heads
        self.supports_masking = True
        self.self_attention = SelfAttention(causal=False, depth=depth, dropout_rate=dropout_rate, epsilon=epsilon,
                                            max_length=max_length, num_heads=num_heads)
        self.ffn = FeedForward(activation=activation, depth=depth, dropout_rate=dropout_rate, epsilon=epsilon)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.self_attention(inputs, training=training)
        x = self.ffn(x, training=training)
        return x

    def get_config(self) -> dict:
        return {
            'activation': self.activation,
            'depth': self.depth,
            'dropout_rate': self.dropout_rate,
            'epsilon': self.epsilon,
            'max_length': self.max_length,
            'name': self.name,
            'num_heads': self.num_heads}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
```
```python
# MODELS >>>
@keras.saving.register_keras_serializable(package='belka', name='SingleOutput')
class Belka(tf.keras.Model):
    def __init__(self, dropout_rate: float, mode: str, num_layers: int, vocab_size: int, **kwargs):
        super(Belka, self).__init__()
        # Arguments
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.mode = mode
        # Layers (the remaining hyperparameters are read from the global `parameters` dict)
        self.embeddings = Embeddings(input_dim=vocab_size, name='smiles_emb', **parameters)
        self.encoder = [EncoderLayer(name='encoder_{}'.format(i), **parameters) for i in range(num_layers)]
        if mode == 'mlm':
            # Per-token softmax over the vocabulary
            self.head = Dense(units=vocab_size, activation='softmax', name='smiles')
        else:
            # Pooled sigmoid head: 3 binding labels ('clf') or 2048 ECFP bits ('fps')
            self.head = Sequential([
                GlobalAvgPool1D(),
                Dropout(dropout_rate),
                Dense(units=3 if mode == 'clf' else 2048, activation='sigmoid')])

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.embeddings(inputs, training=training)
        for encoder in self.encoder:
            x = encoder(x, training=training)
        x = self.head(x, training=training)
        return x

    def get_config(self) -> dict:
        return {
            'dropout_rate': self.dropout_rate,
            'mode': self.mode,
            'num_layers': self.num_layers,
            'vocab_size': self.vocab_size}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**config)
```
```python
# DATASETS >>>
def train_val_set(batch_size: int, buffer_size: int, masking_rate: float, max_length: int, mode: str, seed: int,
                  vocab_size: int, working: str, **kwargs) -> tuple:
    """
    Make train and validation datasets.
    """
    # Constants
    auto = tf.data.AUTOTUNE
    encoder = get_smiles_encoder(**parameters)

    # Helper functions >>>
    def encode_smiles(x, **kwargs) -> dict:
        """
        Encode SMILES strings.
        """
        x['smiles'] = tf.io.parse_tensor(x['smiles'], out_type=tf.string)
        x['smiles'] = tf.cast(encoder(x['smiles']), dtype=tf.int32)
        return x

    def get_model_inputs(x) -> tuple:
        """
        MLM mode: mask [mask_rate] of non-zero tokens.
        * 80% of the time: Replace with the [MSK].
        * 10% of the time: Replace with a random token.
        * 10% of the time: Keep the token unchanged.
        """
        if mode == 'mlm':
            # Get paddings mask (0 = [PAD])
            paddings_mask = tf.cast(x['smiles'] != 0, dtype=tf.float32)
            # Sample one of 4 actions per token: not selected / [MASK] / random token / unchanged
            probs = tf.stack([1.0 - masking_rate, masking_rate * 0.8, masking_rate * 0.1, masking_rate * 0.1], axis=0)
            probs = tf.expand_dims(probs, axis=0)
            probs = tf.ones_like(x['smiles'], dtype=tf.float32)[:, :4] * probs
            probs = tf.math.log(probs)
            mask = tf.multiply(
                tf.one_hot(
                    indices=tf.random.categorical(logits=probs, num_samples=tf.shape(x['smiles'])[-1], seed=seed),
                    depth=4,
                    dtype=tf.float32),
                tf.expand_dims(paddings_mask, axis=-1))
            mask = tf.cast(mask, dtype=tf.int32)
            # Compute masked inputs ([MASK] = 1; random ids in [2, vocab_size), maxval is exclusive)
            x['masked'] = tf.multiply(
                mask,
                tf.stack(
                    values=[
                        x['smiles'],
                        tf.ones_like(x['smiles']),
                        tf.random.uniform(shape=tf.shape(x['smiles']), minval=2, maxval=vocab_size, dtype=tf.int32),
                        x['smiles']],
                    axis=-1))
            x['masked'] = tf.reduce_sum(x['masked'], axis=-1)
            mask = tf.reduce_sum(mask[:, :, 1:], axis=-1)
            # Targets: [original id at predicted positions, -1 elsewhere; original ids for class frequencies]
            x['smiles'] = tf.stack(
                values=[(x['smiles'] * mask) - (1 - mask), x['smiles']],
                axis=-1)
            return x['masked'], x['smiles']
        elif mode == 'fps':
            return x['smiles'], x['ecfp']
        else:
            return x['smiles'], x['binds']

    # Read subsets
    padded_shapes = {'smiles': (max_length,), 'ecfp': (2048,), 'binds': (3,)}
    train = tf.data.Dataset.load(os.path.join(working, 'belka.tfr'), compression='GZIP')
    if mode == 'mlm':
        # MLM uses every molecule (train, validation, test and external data) and has no validation set
        features = ['smiles']
        subsets = {
            'train': train,
            'val': None}
    elif mode == 'fps':
        features = ['smiles', 'ecfp']
        subsets = {
            'train': train.filter(lambda x: tf.not_equal(x['subset'], 1)),
            'val': tf.data.Dataset.load(os.path.join(working, 'belka_val.tfr'), compression='GZIP')}
    else:
        features = ['smiles', 'binds']
        subsets = {
            'train': train.filter(lambda x: tf.equal(x['subset'], 0)),
            'val': tf.data.Dataset.load(os.path.join(working, 'belka_val.tfr'), compression='GZIP')}
    # Preprocess subsets: Cache -> [Repeat -> Shuffle] -> Encode SMILES -> Batch -> Get inputs
    for key in [k for k, v in subsets.items() if v is not None]:
        subset = subsets[key].map(lambda x: {f: x[f] for f in features}, num_parallel_calls=auto)
        subset = subset.cache()
        if key == 'train':
            subset = subset.repeat().shuffle(buffer_size=buffer_size)
        subset = subset.map(lambda x: encode_smiles(x), num_parallel_calls=auto)
        subset = subset.padded_batch(batch_size=batch_size, padded_shapes={
            f: padded_shapes[f] for f in features})
        subsets[key] = subset.map(lambda x: get_model_inputs(x), num_parallel_calls=auto)
    return subsets['train'], subsets['val']
```
```python
# TRAIN & SUBMISSION >>>
def train_model(model: Union[str, None], epochs: int, initial_epoch: int, mode: str, model_name: str, patience: int,
                steps_per_epoch: int, validation_steps: int, working: str, **kwargs):
    """
    Train the model.
    """
    # Train/val subsets
    train, val = train_val_set(**parameters)
    # Build the model, or resume from a saved .keras file
    if model is not None:
        model = load_model(model)
    else:
        model = Belka(**parameters)
    if mode == 'mlm':
        loss = CategoricalLoss(mask=-1, **parameters)
        metrics = MaskedAUC(mask=-1, multi_label=False, num_labels=None, **parameters)
    elif mode == 'fps':
        loss = BinaryLoss(**parameters)
        metrics = MaskedAUC(mask=-1, multi_label=False, num_labels=None, **parameters)
    else:
        loss = MultiLabelLoss(macro=True, **parameters)
        metrics = MaskedAUC(mask=2, multi_label=True, num_labels=3, **parameters)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss, metrics=metrics)
    # Callbacks
    suffix = {
        'mlm': '_{epoch:03d}_{loss:.4f}.model.keras',
        'fps': '_{epoch:03d}_{auc:.4f}_{val_auc:.4f}.model.keras',
        'clf': '_{epoch:03d}_{auc:.4f}_{val_auc:.4f}.model.keras'}
    model_saver = keras.callbacks.ModelCheckpoint(
        monitor='loss', mode='min',
        filepath=os.path.join(working, model_name + suffix[mode]),
        save_best_only=False,
        save_weights_only=False)
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='loss', mode='min', patience=patience, restore_best_weights=True)
    learning_rate = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, monitor='loss')
    callbacks = [model_saver, early_stopping, learning_rate]
    # Build the model on one batch, then print its summary
    x, y_true = iter(train).get_next()
    y_pred = model(x)
    model.summary()
    # Fit the model (no validation in MLM mode)
    validation_steps = None if mode == 'mlm' else validation_steps
    model.fit(train, epochs=epochs, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch,
              validation_data=val, validation_steps=validation_steps, callbacks=callbacks)
    return model


def make_submission(batch_size: int, max_length: int, model: str, working: str, **kwargs) -> None:
    """
    Make submission.
    Expects `model` to be a SavedModel directory with a `serve` endpoint (e.g. written by keras.Model.export).
    """
    # Make test dataset >>>
    df = read_parquet(subset='test', **parameters)
    df['smiles'] = df['smiles'].mapply(lambda x: atomInSmiles.smiles_tokenizer(x))
    ds = tf.data.Dataset.from_tensor_slices(
        {'smiles': tf.ragged.constant(df['smiles'].tolist())})
    # Tokenize -> Zero-pad -> Batch -> Cast
    encoder = get_smiles_encoder(**parameters)
    ds = ds.map(lambda x: tf.cast(encoder(x['smiles']), dtype=tf.int32))
    ds = ds.padded_batch(batch_size=batch_size, padded_shapes=(max_length,))
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    # Make predictions >>>
    model = tf.saved_model.load(model)
    pred = np.zeros(shape=(0, 3), dtype=np.float32)
    for batch in ds:
        pred = np.concatenate([pred, model.serve(batch)])
    # Write predictions to csv: one row per (molecule, protein) id present in the test set
    cols = ['BRD4_pred', 'HSA_pred', 'sEH_pred']
    df[cols] = pred
    cols = [['BRD4', 'BRD4_pred'], ['HSA', 'HSA_pred'], ['sEH', 'sEH_pred']]
    df = np.concatenate([df[col].to_numpy() for col in cols], axis=0)
    df = pd.DataFrame(data=df, columns=['id', 'binds'])
    df = df.dropna().sort_values(by='id').reset_index(drop=True)
    df['id'] = df['id'].astype(int)
    df.to_csv(os.path.join(working, 'submission.csv'), index=False)
    return None
```
```python
# UTILS >>>
def read_parquet(subset: str, root: str, **kwargs) -> pd.DataFrame:
    """
    Read and preprocess train/test parquet files.
    """
    # Read train or test set
    df = pd.read_parquet(os.path.join(root, f'{subset}.parquet'))
    df = df.rename(columns={
        'buildingblock1_smiles': 'block1',
        'buildingblock2_smiles': 'block2',
        'buildingblock3_smiles': 'block3',
        'molecule_smiles': 'smiles'})
    # Group by molecule -> one column per protein: labels (train) or row ids (test)
    cols = ['block1', 'block2', 'block3', 'smiles']
    values = 'binds' if subset == 'train' else 'id'
    df = df.pivot(index=cols, columns='protein_name', values=values).reset_index()
    return df


def make_parquet(working: str, seed: int, **kwargs) -> None:
    """
    Make Dask DataFrame:
    * Read and shuffle dataframe.
    * Stack labels binding affinity [BRD4, HSA, sEH]. Nan mask = 2.
    * Validation split (at least one non-shared block).
    * Add train/validation/test indicator (0/1/2).
    * Replace [Dy] DNA-linker with [H]
    * Get ECFPs
    * Write to parquet file.
    Source of extra data: https://chemrxiv.org/engage/chemrxiv/article-details/6438943f08c86922ffeffe57
    Processed by: @chemdatafarmer @Hengck23
    """
    def validation_split(x, test: set):
        """
        Get train (0), or validation (1) indicators.
        Train subset: zero-intersection between blocks and "test" blocks.
        Validation subset: non-zero-intersection between blocks and "test" blocks.
        """
        blocks = set(x[col] for col in ['block1', 'block2', 'block3'])
        i = len(blocks.intersection(test))
        i = 0 if i == 0 else 1
        i = np.int8(i)
        return i

    def replace_linker(smiles):
        """
        Replace [Dy] linker with hydrogen.
        """
        smiles = smiles.replace('[Dy]', '[H]')
        smiles = Chem.CanonSmiles(smiles)
        return smiles

    # Iterate over subsets
    dataset = []
    for subset in ['test', 'extra', 'train']:
        # Read parquet/csv
        if subset in ['train', 'test']:
            df = read_parquet(subset=subset, **parameters)
        else:
            df = pd.read_csv(os.path.join(working, 'DNA_Labeled_Data.csv'), usecols=['new_structure', 'read_count'])
            df = df.rename(columns={'new_structure': 'smiles', 'read_count': 'binds'})
        # Stack binding affinity labels
        cols = ['BRD4', 'HSA', 'sEH']
        if subset == 'train':
            df['binds'] = np.stack([df[col].to_numpy() for col in cols], axis=-1, dtype=np.int8).tolist()
        elif subset == 'test':
            df["binds"] = np.tile(np.array([[2, 2, 2]], dtype=np.int8), reps=(df.shape[0], 1)).tolist()
        else:
            # External data: only sEH is labelled (read count clipped to 0/1)
            df['binds'] = df['binds'].mapply(lambda x: [2, 2, np.clip(x, a_min=0, a_max=1)])
        for col in cols:
            df = df.drop(columns=[col]) if col in df.columns else df
        # Validation split (blocks are sorted so the split is reproducible across runs)
        if subset == 'train':
            blocks = sorted(set(df['block1'].tolist()) | set(df['block2'].tolist()) | set(df['block3'].tolist()))
            _, val, _, _ = train_test_split(blocks, blocks, test_size=0.03, random_state=seed)
            val = set(val)
            df['subset'] = df.mapply(lambda x: validation_split(x, test=val), axis=1)
        elif subset == 'test':
            df['subset'] = 2
        else:
            df['subset'] = 0  # Use extra data only for training
        cols = ['block1', 'block2', 'block3']
        for col in cols:
            df = df.drop(columns=[col]) if col in df.columns else df
        # Replace [Dy] DNA-linker with [H]
        df['smiles_no_linker'] = df['smiles'].mapply(lambda x: replace_linker(smiles=x))
        # Append dataframe to list
        dataset.append(df)
    # Concatenate -> Shuffle -> Convert to Dask Dataframe
    df = pd.concat(dataset)
    df = df.sample(frac=1.0, ignore_index=True, random_state=seed)
    df = dd.from_pandas(df, npartitions=20)
    # Write to parquet
    df.to_parquet(os.path.join(working, 'belka.parquet'), schema={
        'smiles': pa.string(),
        'binds': pa.list_(pa.int8(), 3),
        'subset': pa.int8(),
        'smiles_no_linker': pa.string()})
    return None


def make_dataset(working: str, **kwargs) -> None:
    """
    Make TFR dataset.
    """
    def generator() -> dict:
        for row in df.itertuples(index=False, name='Row'):
            yield {
                'smiles': row.smiles,
                'smiles_no_linker': row.smiles_no_linker,
                'binds': row.binds,
                'subset': row.subset}

    def serialize_smiles(x) -> dict:
        """
        Serialize the variable-length token tensor to a single string.
        """
        x['smiles'] = tf.io.serialize_tensor(x['smiles'])
        return x

    def get_ecfp(x) -> dict:
        """
        Compute ECFP from "smiles_no_linker".
        """
        x['ecfp'] = transformer(x['smiles_no_linker'])
        x.pop('smiles_no_linker')
        return x

    # Read dataset
    df = dd.read_parquet(os.path.join(working, 'belka.parquet'))
    df = df.compute()
    # Tokenize SMILES
    df['smiles'] = df['smiles'].mapply(lambda x: atomInSmiles.smiles_tokenizer(x))
    # Save with tf.data.Dataset.save (GZIP-compressed)
    auto = tf.data.AUTOTUNE
    transformer = FPGenerator()
    ds = tf.data.Dataset.from_generator(
        generator=lambda: generator(),
        output_signature={
            'smiles': tf.TensorSpec(shape=(None,), dtype=tf.string),
            'smiles_no_linker': tf.TensorSpec(shape=(), dtype=tf.string),
            'binds': tf.TensorSpec(shape=(3,), dtype=tf.int8),
            'subset': tf.TensorSpec(shape=(), dtype=tf.int8)})
    ds = ds.map(lambda x: serialize_smiles(x))
    ds = ds.batch(batch_size=1024, num_parallel_calls=auto)
    ds = ds.map(lambda x: get_ecfp(x), num_parallel_calls=auto)
    ds = ds.unbatch()
    ds.save(os.path.join(working, 'belka.tfr'), compression='GZIP')
    return None


def get_vocab(working: str, **kwargs) -> None:
    """
    Get vocabulary for SMILES encoding.
    """
    # Read parquet
    df = dd.read_parquet(os.path.join(working, 'belka.parquet'))
    df = df.compute()
    # Tokenize SMILES -> Get list of unique tokens
    df['smiles'] = df['smiles'].mapply(lambda x: list(set(atomInSmiles.smiles_tokenizer(x))))
    vocab = np.unique(list(itertools.chain.from_iterable(df['smiles'].tolist()))).tolist()
    vocab = pd.DataFrame(data=vocab)
    vocab.to_csv(os.path.join(working, 'vocab.txt'), index=False, header=False)
    return None


def get_smiles_encoder(vocab: str, **kwargs) -> TextVectorization:
    """
    Get TextVectorization SMILES encoder.
    TextVectorization reserves id 0 for padding ('') and id 1 for OOV ([UNK]); id 1 is reused as [MASK].
    """
    tokenizer = TextVectorization(
        standardize=None,
        split=None,
        vocabulary=vocab)
    return tokenizer


def load_model(model: str, **kwargs) -> tf.keras.Model:
    """
    Load the model.
    """
    model = tf.keras.models.load_model(model, compile=True, custom_objects={
        'Encodings': Encodings,
        'Embeddings': Embeddings,
        'FeedForward': FeedForward,
        'SelfAttention': SelfAttention,
        'EncoderLayer': EncoderLayer,
        'MultiLabelLoss': MultiLabelLoss,
        'CategoricalLoss': CategoricalLoss,
        'BinaryLoss': BinaryLoss,
        'MaskedAUC': MaskedAUC})
    return model


def set_parameters(
        activation: str, batch_size: int, buffer_size: Union[int, float],
        depth: int,
        dropout_rate: float, epochs: int, epsilon: float, initial_epoch: int,
        masking_rate: float, max_length: int, mode: str, model: Union[str, None],
        model_name: str, num_heads: int, num_layers: int,
        patience: int, root: str, seed: int, steps_per_epoch: int, validation_steps: int, vocab: str, vocab_size: int,
        working: str) -> dict:
    """
    Set uniform parameters for the functions in the scope of the project.
    :param mode: Choose from ['clf', 'fps', 'mlm'].
    :param vocab_size: Set to N + 2, where N is the number of tokens ([PAD] = 0, [MASK] = 1).
    :return: Dictionary of parameters.
    """
    inputs = {
        'activation': activation,
        'batch_size': batch_size,
        'buffer_size': int(buffer_size),
        'depth': depth,
        'dropout_rate': dropout_rate,
        'epochs': epochs,
        'epsilon': epsilon,
        'initial_epoch': initial_epoch,
        'masking_rate': masking_rate,
        'max_length': max_length,
        'mode': mode,
        'model': model,
        'model_name': model_name,
        'num_heads': num_heads,
        'num_layers': num_layers,
        'patience': patience,
        'root': root,
        'seed': seed,
        'steps_per_epoch': steps_per_epoch,
        'validation_steps': validation_steps,
        'vocab': vocab,
        'vocab_size': vocab_size,
        'working': working}
    return inputs


mapply.init(n_workers=-1, progressbar=True)
parameters = set_parameters(
    root='leash-BELKA',
    working='working/',
    vocab='vocab.txt',
    model=None,
    mode='clf',
    model_name='belka',
    masking_rate=0.15,
    batch_size=2048, buffer_size=1e07,
    epochs=1000, initial_epoch=0, steps_per_epoch=10000, validation_steps=2000,
    max_length=128, vocab_size=43,
    depth=32, dropout_rate=0.1, num_heads=8, num_layers=4, activation='gelu',
    patience=20, epsilon=1e-07, seed=42)
# Run order inferred from the file dependencies above (not stated by the author):
# make_parquet -> get_vocab -> make_dataset -> train_model with mode 'mlm', then 'fps', then 'clf' -> make_submission.
# train_val_set() also loads 'belka_val.tfr' (the validation subset), which this listing does not create.
```