FakeOrange

Deep Learning Training Pipeline for Small Molecule-Protein Interaction

Validating the feasibility of a model training approach at relatively low cost


Introduction


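This article describes a training pipeline for deep-learning prediction of small molecule-protein binding, built for the NeurIPS BELKA competition and trained on the BELKA dataset. Among the publicly shared solutions, it has been validated as the best approach that meets the competition's time limits and execution requirements.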



Model architecture and tokenization


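The architecture is simple and shallow: only 4 encoder layers with 8 heads each. Because the vocabulary has only 43 tokens, I fixed the model dimension at 32. I also tried 64 and 16, but both performed worse.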
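To make "simple" concrete, the sketch below counts trainable parameters for this configuration. It assumes Keras's standard MultiHeadAttention layout (separate query, key and value projections of shape depth x heads x key_dim with biases, plus an output projection back to depth), the Dense(2 x depth) -> Dense(depth) feed-forward block of the FeedForward class below, and vocab_size = 43 for both the embedding and the MLM head. These figures are derived from reading the code, not taken from a reported model summary.

```python
def mha_params(depth: int, heads: int, key_dim: int) -> int:
    # Query, key and value projections: kernel (depth, heads, key_dim) plus bias (heads, key_dim)
    qkv = 3 * (depth * heads * key_dim + heads * key_dim)
    # Output projection: kernel (heads, key_dim, depth) plus bias (depth,)
    out = heads * key_dim * depth + depth
    return qkv + out


def encoder_layer_params(depth: int, heads: int, key_dim: int) -> int:
    layer_norms = 2 * (2 * depth)  # two LayerNorms, each with gamma and beta
    ffn = (depth * 2 * depth + 2 * depth) + (2 * depth * depth + depth)  # Dense(2d) then Dense(d)
    return mha_params(depth, heads, key_dim) + layer_norms + ffn


def mlm_model_params(vocab_size: int = 43, depth: int = 32, heads: int = 8, key_dim: int = 32,
                     num_layers: int = 4) -> int:
    embedding = vocab_size * depth          # sinusoidal positional encodings add no weights
    head = depth * vocab_size + vocab_size  # Dense(vocab_size, softmax) for MLM
    return embedding + num_layers * encoder_layer_params(depth, heads, key_dim) + head


print(mlm_model_params())  # 154347
```

Under these assumptions the MLM model has roughly 154k parameters, most of them in the attention projections: with 8 heads of key_dim 32, every head projects to the full model width of 32.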

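I probably used atomInSmiles somewhat incorrectly, which produced an unusual tokenization scheme: each token is either an atom (such as C, H, S), a digit, or a bracketed expression (for example [C@@]). Chemists will have a better sense of what these tokens represent.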
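A scheme like this (atoms, digits and bracketed expressions as tokens) can be approximated with one regular expression. This is a hypothetical stand-in, not the atomInSmiles API. The pattern, the two-letter element list (only Cl and Br) and the id convention [PAD] = 0, [MASK] = 1 are assumptions; the last is chosen to match how the loss code below treats ids 0 and 1.

```python
import re

# Bracket expressions, two-letter organic-subset elements, two-digit ring closures, else one character
TOKEN_RE = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def tokenize(smiles: str) -> list[str]:
    return TOKEN_RE.findall(smiles)


def build_vocab(smiles_list: list[str]) -> dict[str, int]:
    # Reserve 0 for padding and 1 for the MLM mask token; real tokens start at 2
    tokens = sorted({t for s in smiles_list for t in tokenize(s)})
    vocab = {"[PAD]": 0, "[MASK]": 1}
    vocab.update({t: i + 2 for i, t in enumerate(tokens)})
    return vocab


def encode(smiles: str, vocab: dict[str, int], max_length: int) -> list[int]:
    ids = [vocab[t] for t in tokenize(smiles)][:max_length]
    return ids + [0] * (max_length - len(ids))  # right-pad with [PAD]


print(tokenize("C[C@@H](Cl)c1ccccc1"))
# ['C', '[C@@H]', '(', 'Cl', ')', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1']
```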



Pre-training


The model is pre-trained from scratch in two stages:

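  • MLM: standard masked-token prediction. 15% of the tokens are selected; of these, 80% are replaced with [MASK], 10% with a random token, and 10% are left unchanged (the same scheme as in "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"). I used a categorical focal cross-entropy loss with dynamic per-mini-batch alpha weights, but only as an experiment; I am not sure whether it contributed to the result. Training ran for about 100 epochs of 10,000 steps each with 2028 samples per batch, so the model saw the data about 20 times. The training data included everything: the training set, the test set and the external data (see the reference below).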


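  • SMILES to ECFP (size = 2048, chirality included). Same model, but with a different output layer (a Dense layer with sigmoid activation) and frozen embeddings. Trained for 20-50 epochs. The model performed only moderately, reaching a MAP of about 0.4 (several studies report that SMILES encoders struggle to predict topological fingerprints, and that was the case here), but it did learn some useful representations in the process.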
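The MLM masking rule from the first stage can be sketched in NumPy. It follows the logic of get_model_inputs in the code below (one of four outcomes drawn independently per non-padding token, so about 15% of tokens are selected, not exactly 15%), but it is a simplified re-expression. PAD = 0, MASK = 1 and the -1 "no target" marker follow the code; the function name mlm_mask is hypothetical.

```python
import numpy as np

PAD, MASK, IGNORE = 0, 1, -1  # padding id, [MASK] id, "no prediction here" target


def mlm_mask(tokens: np.ndarray, vocab_size: int, rate: float = 0.15, rng=None):
    """BERT-style masking of a (batch, length) array of token ids.

    Returns (inputs, targets): targets hold the original id at selected
    positions and IGNORE everywhere else.
    """
    rng = np.random.default_rng() if rng is None else rng
    # 0 = not selected, 1 = [MASK], 2 = random token, 3 = selected but left unchanged
    probs = [1.0 - rate, rate * 0.8, rate * 0.1, rate * 0.1]
    choice = rng.choice(4, size=tokens.shape, p=probs)
    choice = np.where(tokens != PAD, choice, 0)  # padding is never selected

    inputs = tokens.copy()
    inputs[choice == 1] = MASK
    random_ids = rng.integers(2, vocab_size, size=tokens.shape)  # real tokens are 2..vocab_size-1
    inputs[choice == 2] = random_ids[choice == 2]

    targets = np.where(choice > 0, tokens, IGNORE)
    return inputs, targets


rng = np.random.default_rng(0)
batch = rng.integers(2, 43, size=(4, 12))
batch[:, 9:] = PAD  # pretend the last three positions are padding
x, y = mlm_mask(batch, vocab_size=43, rng=rng)
```

The loss is then computed only where targets differ from IGNORE, which is what the mask = -1 argument does for CategoricalLoss and MaskedAUC in the code.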


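ECFP was chosen for two reasons: (a) performance: it is fast to compute, especially with the scikit-fingerprints library; (b) the predicted fingerprint bits have no predefined meaning (unlike MACCS or PubChem keys), which makes the task a real challenge for the SMILES transformer and may therefore yield better feature representations.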
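Point (b) is easier to see with a toy version of the Morgan/ECFP procedure: each atom starts from an identifier built from its own properties, that identifier is repeatedly re-hashed together with its neighbours' identifiers, and every identifier is folded into a fixed-size bit vector modulo its length. A bit therefore marks "some hashed environment landed here", not a named substructure. This is a simplification for illustration only (element and degree only, no bond orders, charges, chirality or duplicate-environment removal) and does not reproduce what scikit-fingerprints computes; the code in this post uses skfp's ECFPFingerprint.

```python
import zlib


def _stable_hash(obj) -> int:
    # Python's built-in hash() is salted for strings, so use CRC32 for reproducible identifiers
    return zlib.crc32(repr(obj).encode())


def toy_ecfp(atoms, bonds, radius: int = 2, n_bits: int = 2048) -> list[int]:
    """atoms: element symbols; bonds: (i, j) atom-index pairs. Returns a 0/1 list of length n_bits."""
    neighbours = {i: [] for i in range(len(atoms))}
    for i, j in bonds:
        neighbours[i].append(j)
        neighbours[j].append(i)

    # Iteration 0: identifier from the atom's own properties (here only element and degree)
    ids = [_stable_hash((atoms[i], len(neighbours[i]))) for i in range(len(atoms))]
    bits = {x % n_bits for x in ids}
    for _ in range(radius):
        # Each new identifier hashes the atom's identifier with the sorted identifiers of its neighbours
        ids = [_stable_hash((ids[i], tuple(sorted(ids[j] for j in neighbours[i]))))
               for i in range(len(atoms))]
        bits.update(x % n_bits for x in ids)  # fold into the fixed-length vector

    return [1 if b in bits else 0 for b in range(n_bits)]


ethanol = toy_ecfp(["C", "C", "O"], [(0, 1), (1, 2)])
print(sum(ethanol))  # number of set bits; at most 3 atoms x 3 iterations = 9
```

Because identifiers depend only on the graph, listing the atoms in a different order gives the same fingerprint.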


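Training (binding classification): the BELKA training set combined with the external data. Because the external data only has labels for the sEH protein, the model uses a masked loss and masked metrics, so missing labels do not contribute.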
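How missing labels and the dynamic mini-batch weights interact is easier to follow in a NumPy re-expression of the MultiLabelLoss class below (macro = True). Label value 2 marks "not measured", per-label class counts give weights alpha = 1/sqrt(count), and masked entries contribute neither to the loss nor to its normalizer. This is a reading aid, not a replacement for the TensorFlow class; the function name masked_focal_loss is hypothetical.

```python
import numpy as np

NAN_LABEL = 2  # "not measured", e.g. BRD4/HSA for rows from the sEH-only external data


def masked_focal_loss(y_true, y_pred, gamma: float = 2.0, eps: float = 1e-7) -> float:
    """y_true: (batch, labels) with values in {0, 1, 2}; y_pred: (batch, labels) probabilities."""
    y_true = np.asarray(y_true)
    mask = (y_true != NAN_LABEL).astype(float)

    # Per-label counts of negatives and positives in this mini-batch -> shape (labels, 2)
    freq = np.stack([(y_true == 0).sum(0), (y_true == 1).sum(0)], axis=-1).astype(float)
    alpha = np.where(freq == 0, 0.0, 1.0 / np.sqrt(np.maximum(freq, 1.0)))
    # Rescale so each label's weights sum to its number of labelled samples
    alpha *= freq.sum(1, keepdims=True) / np.maximum((alpha * freq).sum(1, keepdims=True), 1e-12)

    # Weight of each entry = alpha of its class; missing labels get weight 0
    w = np.where(y_true == 1, alpha[:, 1], np.where(y_true == 0, alpha[:, 0], 0.0))

    y = y_true * mask
    p = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)
    pt = y * p + (1.0 - y) * (1.0 - p)  # probability assigned to the true class
    loss = -w * (1.0 - pt) ** gamma * np.log(pt) * mask
    # Same reduction as the class: normalize per row (axis=1), then average over the batch
    return float(np.mean(loss.sum(1) / (w * mask).sum(1)))
```

Changing predictions at positions labelled 2 leaves the value unchanged, which is what lets sEH-only rows share batches with fully labelled BELKA rows.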

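Validation: 3% of the building blocks in the training set are held out, and every validation molecule contains at least one of these blocks, which never appear in training; about 9 million samples in total.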
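The block-based split can be sketched in plain Python. As in make_parquet below, a fraction of the unique building blocks is held out, and any molecule that contains at least one held-out block goes to validation, so every validation molecule includes a block never seen in training. Using random.sample instead of scikit-learn's train_test_split, rounding the held-out count up, and the name block_split are assumptions of this sketch.

```python
import math
import random


def block_split(molecules, frac: float = 0.03, seed: int = 42):
    """molecules: list of (block1, block2, block3) tuples.

    Returns (flags, held_out): flags[i] is 1 if molecule i goes to validation, else 0.
    """
    # Sort first so the split depends only on the seed, not on set iteration order
    blocks = sorted({b for m in molecules for b in m})
    n_held = math.ceil(frac * len(blocks))
    held_out = set(random.Random(seed).sample(blocks, n_held))
    flags = [int(any(b in held_out for b in m)) for m in molecules]
    return flags, held_out
```

Because a molecule has three blocks, holding out 3% of the blocks moves roughly 1 - 0.97^3, or about 9%, of the molecules into validation, which is consistent with the ~9 million samples mentioned above.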

Hardware: A100 GPU.

Summary: as I said, there is nothing special about this model; the result is a combination of luck and randomness.


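The winning model is a very basic encoder: self-attention -> feed-forward, with 4 layers of 8 heads each and a key/value dimension of 32 per head. It is essentially taken from the Transformer chapter of the TensorFlow tutorials.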
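For readers who want the block spelled out, here is a NumPy forward pass of one pre-LN encoder layer of this shape: LayerNorm -> multi-head self-attention -> residual, then LayerNorm -> Dense(2 x depth) -> Dense(depth) -> residual, with 8 heads and key_dim 32 on a model width of 32. It is a sketch under simplifying assumptions: random weights, no biases, no learned LayerNorm scale or shift, no dropout, ReLU standing in for the configurable activation, and only padded keys are masked. All function names are hypothetical.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Per-token normalization; learned scale/shift omitted
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)


def softmax(z):
    z = z - z.max(-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(-1, keepdims=True)


def init_weights(rng, depth=32, heads=8, key_dim=32):
    s = 1.0 / np.sqrt(depth)
    return {
        "q": rng.normal(0, s, (depth, heads, key_dim)),
        "k": rng.normal(0, s, (depth, heads, key_dim)),
        "v": rng.normal(0, s, (depth, heads, key_dim)),
        "o": rng.normal(0, s, (heads, key_dim, depth)),
        "ff1": rng.normal(0, s, (depth, 2 * depth)),
        "ff2": rng.normal(0, s, (2 * depth, depth)),
    }


def encoder_layer(x, real, w):
    """x: (length, depth) token vectors; real: (length,) bool, False for [PAD]."""
    key_dim = w["q"].shape[-1]
    h = layer_norm(x)
    q, k, v = (np.einsum("ld,dhk->hlk", h, w[n]) for n in ("q", "k", "v"))
    scores = np.einsum("hik,hjk->hij", q, k) / np.sqrt(key_dim)
    scores = np.where(real[None, None, :], scores, -1e9)  # never attend to padding
    ctx = np.einsum("hij,hjk->hik", softmax(scores), v)
    x = x + np.einsum("hlk,hkd->ld", ctx, w["o"])  # residual around attention
    f = np.maximum(layer_norm(x) @ w["ff1"], 0.0) @ w["ff2"]
    return x + f  # residual around the feed-forward block


rng = np.random.default_rng(0)
weights = init_weights(rng)
tokens = rng.normal(size=(10, 32))
real = np.array([True] * 7 + [False] * 3)
out = encoder_layer(tokens, real, weights)  # shape (10, 32)
```

Stacking four such layers, each with its own weights, gives the depth described above; because padded keys receive zero attention weight, the outputs for real tokens do not depend on what sits in the padded positions.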


I used the atomInSmiles tokenizer, but incorrectly, so the tokenization is almost character-based. I did not use any pretrained model (such as ChemBERTa).


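What probably made the difference is the two-stage pre-training:

MLM - 15% masking rate

SMILES-to-ECFP prediction.

My guess is that the second stage is the key to the encoder "learning" meaningful information from SMILES.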

External data: the preprocessed "Building Block-Based Binding Predictions for DNA-Encoded Libraries" dataset. (Adding closely related data to training improves final performance.)
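Merging the external data with BELKA relies on a shared label layout: one slot per protein in the order BRD4, HSA, sEH, with 2 meaning "not measured". The sketch below mirrors the label-stacking step in make_parquet: training rows keep their three labels, test rows get [2, 2, 2], and external rows only fill sEH, turning the read count into a binary label by clipping it to [0, 1] (any positive count becomes a binder). The function name and the dict-based row format are illustrative.

```python
NAN_LABEL = 2
PROTEINS = ("BRD4", "HSA", "sEH")


def stack_labels(subset: str, row: dict) -> list[int]:
    if subset == "train":
        return [int(row[p]) for p in PROTEINS]
    if subset == "test":
        return [NAN_LABEL] * len(PROTEINS)
    # External DEL data: sEH only, given as a read count
    binder = min(max(int(row["read_count"]), 0), 1)
    return [NAN_LABEL, NAN_LABEL, binder]


print(stack_labels("extra", {"read_count": 17}))  # [2, 2, 1]
```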



Potential improvements:


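  1. More complex tokenization schemes: bigram and trigram tokens, atomInSmiles
  2. Models with a dimension above 32 and more than 6 encoder layers
  3. Multi-input models (SMILES + fingerprints)
  4. Pre-training on larger datasets: I spent about a month experimenting with ZINC
  5. Custom loss functions: binary focal loss works well
  6. Gated fusion of building blocks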


Full code


import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras.layers import Dense, LayerNormalization, Add, MultiHeadAttention, Dropout
from keras.layers import Embedding, TextVectorization, GlobalAvgPool1D
from keras import Sequential
from sklearn.model_selection import train_test_split
import os
from typing import Union
import mapply
from skfp.fingerprints import ECFPFingerprint
from rdkit import Chem
import atomInSmiles
import dask.dataframe as dd
import pyarrow as pa
import itertools
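import einops

# mapply provides the .mapply() accessor used on pandas objects below and must be initialized once
mapply.init()

# NOTE: the functions below read hyperparameters and paths from a module-level `parameters` dict
# and call a `get_smiles_encoder` helper; neither is defined in this listing.

# LOSSES and METRICS >>>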

@keras.saving.register_keras_serializable(package='belka', name='MultiLabelLoss')
class MultiLabelLoss(keras.losses.Loss):
    """
    * Macro- or Micro-averaged Weighted Masked Binary Focal loss.
    * Dynamic mini-batch class weights "alpha".
    * Used for binary multilabel classification.
    """

    def __init__(self, epsilon: float, macro: bool, gamma: float = 2.0, nan_mask: int = 2, name='loss', **kwargs):
        super(MultiLabelLoss, self).__init__(name=name)
        self.epsilon = epsilon
        self.gamma = gamma
        self.macro = macro
        self.nan_mask = nan_mask

    def call(self, y_true, y_pred):

        # Cast y_true to tf.int32
        y_true = tf.cast(y_true, dtype=tf.int32)

        # Compute class weights ("alpha"): Inverse of Square Root of Number of Samples
        # Compute "alpha" for each label if "macro" = True
        # Assign zero class weights to missing classes
        # Normalize: sum of sample weights = sample count per label
        freq = tf.math.bincount(
            arr=tf.transpose(y_true, perm=[1,0]) if self.macro else y_true,
            minlength=2, maxlength=2, dtype=tf.float32, axis=-1 if self.macro else 0)
        alpha = tf.where(condition=tf.equal(freq, 0.0), x=0.0, y=tf.math.rsqrt(freq))
        ax = 1 if self.macro else None
        alpha = alpha * tf.reduce_sum(freq, axis=ax, keepdims=True) / tf.reduce_sum(alpha*freq, axis=ax, keepdims=True)
        alpha = tf.reduce_sum(alpha * tf.one_hot(y_true, axis=-1, depth=2, dtype=tf.float32), axis=-1)

        # Mask and set to zero missing labels
        y_true = tf.cast(y_true, tf.float32)
        mask = tf.cast(tf.not_equal(y_true, tf.constant(self.nan_mask, tf.float32)), dtype=tf.float32)
        y_true = y_true * mask

        # Compute loss
        y_pred = tf.clip_by_value(y_pred, clip_value_min=self.epsilon, clip_value_max=1.0 - self.epsilon)
        pt = tf.add(
            tf.multiply(y_true, y_pred),
            tf.multiply(tf.subtract(1.0, y_true), tf.subtract(1.0, y_pred)))
        loss = - alpha * (1.0 - pt) ** self.gamma * tf.math.log(pt) * mask
        ax = 1 if self.macro else None
        loss = tf.divide(tf.reduce_sum(loss, axis=ax), tf.reduce_sum(alpha * mask, axis=ax))
        loss = tf.reduce_mean(loss)
        return loss

    def get_config(self):
        config = super(MultiLabelLoss, self).get_config()
        config.update({
            'epsilon': self.epsilon,
            'gamma': self.gamma,
            'macro': self.macro,
            'nan_mask': self.nan_mask})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='CategoricalLoss')
class CategoricalLoss(keras.losses.Loss):
    """
    Masked Categorical Focal loss.
    Dynamic mini-batch class weights ("alpha").
    Used for MLM training.
    """
    def __init__(self, epsilon: float, mask: int, vocab_size: int, gamma: float = 2.0, name='loss', **kwargs):
        super(CategoricalLoss, self).__init__(name=name)
        self.epsilon = epsilon
        self.gamma = gamma
        self.mask = mask
        self.vocab_size = vocab_size

    def call(self, y_true, y_pred):

        # Unpack y_true to masked (y_true) and unmasked (unmasked) arrays
        unmasked = y_true[:,:,1]
        y_true = y_true[:,:,0]

        # Reshape inputs
        y_true = einops.rearrange(y_true, 'b l -> (b l)')
        y_pred = einops.rearrange(y_pred, 'b l c -> (b l) c')

        # Drop non-masked from y_true
        mask = tf.not_equal(y_true, self.mask)
        y_true = tf.boolean_mask(y_true, mask)

        # Compute class weights ("alpha"): Inverse of Square Root of Number of Samples
        # Assign zero class weights to missing classes
        # Normalization happens at the end, by dividing by the summed alpha of the masked targets
        freq = tf.math.bincount(unmasked, minlength=self.vocab_size, dtype=tf.float32)
        freq = tf.concat([tf.zeros(shape=(2,)), freq[2:]], axis=0)  # Set frequencies for [PAD], [MASK] = 0
        alpha = tf.where(condition=tf.equal(freq, 0.0), x=0.0, y=tf.math.rsqrt(freq))

        # Convert y_true to one-hot
        # Apply mask to y_pred
        y_true = tf.one_hot(y_true, depth=self.vocab_size, axis=-1, dtype=tf.float32)
        y_pred = tf.boolean_mask(y_pred, mask, axis=0)

        # Compute loss
        y_pred = tf.clip_by_value(y_pred, clip_value_min=self.epsilon, clip_value_max=1.0 - self.epsilon)
        pt = tf.add(
            tf.multiply(y_true, y_pred),
            tf.multiply(tf.subtract(1.0, y_true), tf.subtract(1.0, y_pred)))
        loss = - alpha * ((1.0 - pt) ** self.gamma) * (y_true * tf.math.log(y_pred))
        loss = tf.divide(tf.reduce_sum(loss), tf.reduce_sum(alpha * y_true))
        return loss

    def get_config(self) -> dict:
        config = super(CategoricalLoss, self).get_config()
        config.update({
            'epsilon': self.epsilon,
            'gamma': self.gamma,
            'mask': self.mask,
            'vocab_size': self.vocab_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='BinaryLoss')
class BinaryLoss(keras.losses.Loss):
    """
    Binary Focal loss.
    Used for FPs training.
    """
    def __init__(self, name='loss', **kwargs):
        super(BinaryLoss, self).__init__(name=name)
        self.loss = tf.keras.losses.BinaryFocalCrossentropy()

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, dtype=tf.float32)
        y_true = tf.reshape(y_true, shape=(-1, 1))
        y_pred = tf.reshape(y_pred, shape=(-1, 1))
        loss = self.loss(y_true, y_pred)
        return loss

    def get_config(self) -> dict:
        config = super(BinaryLoss, self).get_config()
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='MaskedAUC')
class MaskedAUC(keras.metrics.AUC):
    def __init__(self, mode: str, mask: int, multi_label: bool, num_labels: Union[int, None], vocab_size: int,
                 name='auc', **kwargs):
        super(MaskedAUC, self).__init__(curve='PR', multi_label=multi_label, num_labels=num_labels, name=name)
        self.mode = mode
        self.multi_label = multi_label
        self.mask = mask
        self.num_labels = num_labels
        self.vocab_size = vocab_size

    def update_state(self, y_true, y_pred, sample_weight=None):

        if self.mode == 'mlm':

            # Unpack y_true to masked (y_true) and unmasked (unmasked) arrays
            unmasked = y_true[:, :, 1]
            y_true = y_true[:, :, 0]

            # Reshape inputs
            y_true = einops.rearrange(y_true, 'b l -> (b l)')
            y_pred = einops.rearrange(y_pred, 'b l c -> (b l) c')

            # Drop non-masked tokens from y_true
            mask = tf.not_equal(y_true, self.mask)
            y_true = tf.boolean_mask(y_true, mask)

            # Convert y_true to one-hot
            # Apply mask to y_pred
            y_true = tf.one_hot(y_true, depth=self.vocab_size, axis=-1, dtype=tf.float32)
            y_pred = tf.boolean_mask(y_pred, mask, axis=0)
            mask = None

        elif self.mode == 'clf':
            mask = tf.cast(tf.not_equal(y_true, self.mask), dtype=tf.float32)

        else:
            y_true = tf.reshape(y_true, shape=(-1,1))
            y_pred = tf.reshape(y_pred, shape=(-1,1))
            mask = tf.ones_like(y_pred, dtype=tf.float32)

        # Compute macro-averaged mAP
        super().update_state(y_true, y_pred, sample_weight=mask)

    def get_config(self) -> dict:
        config = super(MaskedAUC, self).get_config()
        config.update({
            'mode': self.mode,
            'multi_label': self.multi_label,
            'mask': self.mask,
            'num_labels': self.num_labels,
            'vocab_size': self.vocab_size})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


# LAYERS >>>

class FPGenerator(tf.keras.layers.Layer):
    def __init__(self, name: str = 'fingerprints', **kwargs):
        super(FPGenerator, self).__init__(name=name)
        self.transformer = ECFPFingerprint(include_chirality=True, n_jobs=-1)

    def call(self, inputs, *args, **kwargs):
        """
        Get fingerprints given SMILES string.
        """
        x = tf.py_function(
            func=self.get_fingerprints,
            inp=[inputs],
            Tout=tf.int8)
        return x

    def get_fingerprints(self, inputs):
        x = inputs.numpy().astype(str)
        x = self.transformer.transform(x)
        x = tf.constant(x, dtype=tf.int8)
        return x


@keras.saving.register_keras_serializable(package='belka', name='Encodings')
class Encodings(keras.layers.Layer):
    def __init__(self, depth: int, max_length: int, name: str = 'encodings', **kwargs):
        super(Encodings, self).__init__(name=name)
        self.depth = depth
        self.max_length = max_length
        self.encodings = self._pos_encodings(depth=depth, max_length=max_length)

    def call(self, inputs, training=False, *args, **kwargs):
        scale = tf.ones_like(inputs) * tf.math.sqrt(tf.cast(self.depth, tf.float32))
        x = tf.multiply(inputs, scale)
        x = tf.add(x, self.encodings[tf.newaxis, :tf.shape(x)[1], :])
        return x

    @staticmethod
    def _pos_encodings(depth: int, max_length: int):
        """
        Get positional encodings of shape [max_length, depth]
        """
        positions = tf.range(max_length, dtype=tf.float32)[:, tf.newaxis]
        idx = tf.range(depth)[tf.newaxis, :]
        power = tf.cast(2 * (idx // 2), tf.float32)
        power /= tf.cast(depth, tf.float32)
        angles = 1. / tf.math.pow(10000., power)
        radians = positions * angles
        sin = tf.math.sin(radians[:, 0::2])
        cos = tf.math.cos(radians[:, 1::2])
        encodings = tf.concat([sin, cos], axis=-1)
        return encodings

    def get_config(self) -> dict:
        return {
            'depth': self.depth,
            'max_length': self.max_length,
            'name': self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='Embeddings')
class Embeddings(tf.keras.layers.Layer):
    def __init__(self, max_length: int, depth: int, input_dim: int, name: str = 'embeddings', **kwargs):
        super(Embeddings, self).__init__(name=name)
        self.depth = depth
        self.max_length = max_length
        self.input_dim = input_dim
        self.embeddings = Embedding(input_dim=input_dim, output_dim=depth, mask_zero=True)
        self.encodings = Encodings(depth=depth, max_length=max_length)

    def build(self, input_shape):
        self.embeddings.build(input_shape=input_shape)
        super().build(input_shape)

    def compute_mask(self, *args, **kwargs):
        return self.embeddings.compute_mask(*args, **kwargs)

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.embeddings(inputs)
        x = self.encodings(x)
        return x

    def get_config(self) -> dict:
        return {
            'depth': self.depth,
            'input_dim': self.input_dim,
            'max_length': self.max_length,
            'name': self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='FeedForward')
class FeedForward(tf.keras.layers.Layer):
    def __init__(self, activation: str, depth: int, dropout_rate: float, epsilon: float, name: str = 'ffn', **kwargs):
        super(FeedForward, self).__init__(name=name)
        self.activation = activation
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        self.norm = LayerNormalization(epsilon=epsilon)
        self.dense1 = Dense(units=int(depth * 2), activation=activation)
        self.dense2 = Dense(units=depth)
        self.dropout = Dropout(rate=dropout_rate)
        self.add = Add()

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.norm(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dropout(x, training=training)
        x = self.add([x, inputs])
        return x

    def get_config(self) -> dict:
        return {
            'activation': self.activation,
            'depth': self.depth,
            'dropout_rate': self.dropout_rate,
            'epsilon': self.epsilon,
            'name': self.name}


    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='SelfAttention')
class SelfAttention(tf.keras.layers.Layer):
    """
    * Self-Attention block with PRE-layer normalization
    * LayerNorm -> MHA -> Skip connection
    """
    def __init__(self, causal: bool, depth: int, dropout_rate: float, epsilon: float, max_length: int, num_heads: int,
                 name: str = 'self_attention', **kwargs):
        super(SelfAttention, self).__init__(name=name)
        self.causal = causal
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        self.max_length = max_length
        self.num_heads = num_heads
        self.supports_masking = True
        self.norm = LayerNormalization(epsilon=epsilon)
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=depth, dropout=dropout_rate)
        self.add = Add()

    def build(self, input_shape):
        self.mha.build(input_shape=[input_shape, input_shape])
        super().build(input_shape)

    def call(self, inputs, training=False, *args, **kwargs):

        # Compute attention mask
        mask = tf.cast(inputs._keras_mask, dtype=tf.float32)
        m1 = tf.expand_dims(mask, axis=2)
        m2 = tf.expand_dims(mask, axis=1)
        mask = tf.cast(tf.linalg.matmul(m1, m2), dtype=tf.bool)

        # Compute outputs
        x = self.norm(inputs)
        x = self.mha(
            query=x,
            value=x,
            use_causal_mask=self.causal,
            training=training,
            attention_mask=mask)
        x = self.add([x, inputs])

        return x

    def get_config(self) -> dict:
        return {
            'causal': self.causal,
            'depth': self.depth,
            'dropout_rate': self.dropout_rate,
            'epsilon': self.epsilon,
            'max_length': self.max_length,
            'name': self.name,
            'num_heads': self.num_heads}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@keras.saving.register_keras_serializable(package='belka', name='EncoderLayer')
class EncoderLayer(tf.keras.layers.Layer):
    """
    * Encoder layer with PRE-layer normalization: LayerNorm -> Self-Attention -> LayerNorm -> FeedForward.
    """
    def __init__(self, activation: str, depth: int, dropout_rate: float, epsilon: float, max_length: int,
                 num_heads: int, name: str = 'encoder_layer', **kwargs):
        super(EncoderLayer, self).__init__(name=name)
        self.activation = activation
        self.depth = depth
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        self.max_length = max_length
        self.num_heads = num_heads
        self.supports_masking = True
        self.self_attention = SelfAttention(causal=False, depth=depth, dropout_rate=dropout_rate, epsilon=epsilon,
                                            max_length=max_length, num_heads=num_heads)
        self.ffn = FeedForward(activation=activation, depth=depth, dropout_rate=dropout_rate, epsilon=epsilon)

    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.self_attention(inputs, training=training)
        x = self.ffn(x, training=training)
        return x

    def get_config(self) -> dict:
        return {
            'activation': self.activation,
            'depth': self.depth,
            'dropout_rate': self.dropout_rate,
            'epsilon': self.epsilon,
            'max_length': self.max_length,
            'name': self.name,
            'num_heads': self.num_heads}


    @classmethod
    def from_config(cls, config):
        return cls(**config)


# MODELS >>>

@keras.saving.register_keras_serializable(package='belka', name='SingleOutput')
class Belka(tf.keras.Model):
    def __init__(self, dropout_rate: float, mode: str, num_layers: int, vocab_size: int, **kwargs):
        super(Belka, self).__init__()

        # Arguments
        self.dropout_rate = dropout_rate
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.mode = mode

        #  Layers
        self.embeddings = Embeddings(input_dim=vocab_size, name='smiles_emb', **parameters)
        self.encoder = [EncoderLayer(name='encoder_{}'.format(i), **parameters) for i in range(num_layers)]
        if mode == 'mlm':
            self.head = Dense(units=vocab_size, activation='softmax', name='smiles')
        else:
            self.head = Sequential([
                GlobalAvgPool1D(),
                Dropout(dropout_rate),
                Dense(units = 3 if mode == 'clf' else 2048, activation='sigmoid')])

    def call(self, inputs, training=False, *args, **kwargs):
        x = self.embeddings(inputs, training=training)
        for encoder in self.encoder:
            x = encoder(x, training=training)
        x = self.head(x, training=training)
        return x

    def get_config(self) -> dict:
        return {
            'dropout_rate': self.dropout_rate,  # required by __init__, so from_config needs it
            'mode': self.mode,
            'num_layers': self.num_layers,
            'vocab_size': self.vocab_size}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        return cls(**config)


# DATASETS >>>

def train_val_set(batch_size: int, buffer_size: int, masking_rate: float, max_length: int, mode: str, seed: int,
                  vocab_size: int, working: str, **kwargs) -> tuple:
    """
    Make train and validation datasets.
    """

    # Constants
    auto = tf.data.AUTOTUNE
    encoder = get_smiles_encoder(**parameters)

    # Helper functions >>>

    def encode_smiles(x, **kwargs) -> dict:
        """
        Encode SMILES strings.
        """
        x['smiles'] = tf.io.parse_tensor(x['smiles'], out_type=tf.string)
        x['smiles'] = tf.cast(encoder(x['smiles']), dtype=tf.int32)
        return x

    def get_model_inputs(x) -> tuple:
        """
        MLM mode: mask [mask_rate] of non-zero tokens.
            * 80% of the time: Replace with the [MSK].
            * 10% of the time: Replace with a random token.
            * 10% of the time: Keep the token unchanged.
        """

        if mode == 'mlm':

            # Get paddings mask (0 = [PAD])
            paddings_mask = tf.cast(x['smiles'] != 0, dtype=tf.float32)

            # Get random mask (1 = [MASK])
            probs = tf.stack([1.0 - masking_rate, masking_rate * 0.8, masking_rate * 0.1, masking_rate * 0.1], axis=0)
            probs = tf.expand_dims(probs, axis=0)
            probs = tf.ones_like(x['smiles'], dtype=tf.float32)[:, :4] * probs
            probs = tf.math.log(probs)
            mask = tf.multiply(
                tf.one_hot(
                    # tf.one_hot takes no seed argument; the seed belongs to the sampling op
                    indices=tf.random.categorical(logits=probs, num_samples=tf.shape(x['smiles'])[-1], seed=seed),
                    depth=4,
                    dtype=tf.float32),
                tf.expand_dims(paddings_mask, axis=-1))
            mask = tf.cast(mask, dtype=tf.int32)

            # Compute masked inputs
            x['masked'] = tf.multiply(
                mask,
                tf.stack(
                    values=[
                        x['smiles'],
                        tf.ones_like(x['smiles']),
                        # maxval is exclusive: random ids stay within 2..vocab_size-1
                        tf.random.uniform(shape=tf.shape(x['smiles']), minval=2, maxval=vocab_size, dtype=tf.int32),
                        x['smiles']],
                    axis=-1))
            x['masked'] = tf.reduce_sum(x['masked'], axis=-1)
            mask = tf.reduce_sum(mask[:, :, 1:], axis=-1)

            # Set non-masked values to -1
            x['smiles'] = tf.stack(
                values=[(x['smiles'] * mask) - (1 - mask), x['smiles']],
                axis=-1)

            return x['masked'], x['smiles']

        elif mode == 'fps':
            return x['smiles'], x['ecfp']

        else:
            return x['smiles'], x['binds']

    # Read subsets
    padded_shapes = {'smiles':  (max_length,), 'ecfp': (2048,), 'binds': (3,)}
    train = tf.data.Dataset.load(os.path.join(working, 'belka.tfr'), compression='GZIP')
    if mode == 'mlm':
        features = ['smiles']
        subsets = {
            'train': train,
            'none': None}
    elif mode == 'fps':
        features = ['smiles', 'ecfp']
        subsets = {
            'train': train.filter(lambda x: tf.not_equal(x['subset'], 1)),
            'val': tf.data.Dataset.load(os.path.join(working, 'belka_val.tfr'), compression='GZIP')}
    else:
        features = ['smiles', 'binds']
        subsets = {
            'train': train.filter(lambda x: tf.equal(x['subset'], 0)),
            'val': tf.data.Dataset.load(os.path.join(working, 'belka_val.tfr'), compression='GZIP')}

    # Preprocess subsets:  Cache -> [Repeat -> Shuffle] -> Encode SMILES -> Batch -> Get inputs
    for key in [key for key in subsets.keys() if key != 'none']:
        subset = subsets[key].map(lambda x: {key: x[key] for key in features}, num_parallel_calls=auto)
        subset = subset.cache()
        if key == 'train':
            subset = subset.repeat().shuffle(buffer_size=buffer_size)
        subset = subset.map(lambda x: encode_smiles(x), num_parallel_calls=auto)
        subset = subset.padded_batch(batch_size=batch_size, padded_shapes={
            key: padded_shapes[key] for key in features})
        subsets[key] = subset.map(lambda x: get_model_inputs(x), num_parallel_calls=auto)
    return subsets['train'], subsets.get('val')  # no validation set in MLM mode


# TRAIN & SUBMISSION >>>

def train_model(model: Union[str, None], epochs: int, initial_epoch: int, mode: str, model_name: str, patience: int,
                steps_per_epoch: int, validation_steps: int, working: str, **kwargs):
    """
    Train the model.
    """

    # Train/val subsets
    train, val = train_val_set(**parameters)

    # Build the model
    if model is not None:
        model = keras.models.load_model(model)
    else:
        model = Belka(**parameters)
        if mode == 'mlm':
            loss = CategoricalLoss(mask=-1, **parameters)
            metrics = MaskedAUC(mask=-1, multi_label=False, num_labels=None, **parameters)
        elif mode == 'fps':
            loss = BinaryLoss(**parameters)
            metrics = MaskedAUC(mask=-1, multi_label=False, num_labels=None, **parameters)
        else:
            loss = MultiLabelLoss(macro=True, **parameters)
            metrics = MaskedAUC(mask=2, multi_label=True, num_labels=3, **parameters)
        model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss, metrics=metrics)


    # Callbacks
    suffix = {
        'mlm': '_{epoch:03d}_{loss:.4f}.model.keras',
        'fps': '_{epoch:03d}_{auc:.4f}_{val_auc:.4f}.model.keras',
        'clf': '_{epoch:03d}_{auc:.4f}_{val_auc:.4f}.model.keras'}
    model_saver = keras.callbacks.ModelCheckpoint(
        monitor='loss', mode='min',
        filepath=os.path.join(working, model_name + suffix[mode]),
        save_best_only=False,
        save_weights_only=False)
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='loss', mode='min', patience=patience, restore_best_weights=True)
    learning_rate = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, monitor='loss')
    callbacks = [model_saver, early_stopping, learning_rate]

    # Print model summary
    x, y_true = iter(train).get_next()
    y_pred = model(x)
    model.summary()

    # Fit the model
    validation_steps = None if mode == 'mlm' else validation_steps
    model.fit(train, epochs=epochs, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch,
              validation_data=val, validation_steps=validation_steps, callbacks=callbacks)

    return model


def make_submission(batch_size: int, max_length: int, model: str, working: str, **kwargs) -> None:
    """
    Make submission.
    """

    # Make test dataset >>>
    df = read_parquet(subset='test', **parameters)
    df['smiles'] = df['smiles'].mapply(lambda x: atomInSmiles.smiles_tokenizer(x))
    ds = tf.data.Dataset.from_tensor_slices(
        {'smiles': tf.ragged.constant(df['smiles'].tolist())})

    # Tokenize -> Zero-pad -> Batch -> Cast
    encoder = get_smiles_encoder(**parameters)
    ds = ds.map(lambda x: tf.cast(encoder(x['smiles']), dtype=tf.int32))
    ds = ds.padded_batch(batch_size=batch_size, padded_shapes=(max_length,))
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)

    # Make predictions >>>
    model = tf.saved_model.load(model)
    pred = np.zeros(shape=(0,3), dtype=np.float32)
    for batch in ds:
        pred = np.concatenate([pred, model.serve(batch)])
    print('\r')

    # Write predictions to csv
    cols = ['BRD4_pred', 'HSA_pred', 'sEH_pred']
    df[cols] = pred
    cols = [['BRD4', 'BRD4_pred'], ['HSA', 'HSA_pred'], ['sEH', 'sEH_pred']]
    df = np.concatenate([df[col].to_numpy() for col in cols], axis=0)
    df = pd.DataFrame(data=df, columns=['id', 'binds'])
    df = df.dropna().sort_values(by='id').reset_index(drop=True)
    df['id'] = df['id'].astype(int)
    df.to_csv(os.path.join(working, 'submission.csv'), index=False)

    return None


# UTILS >>>

def read_parquet(subset: str, root: str, **kwargs) -> pd.DataFrame:
    """
    Read and preprocess train/test parquet files.
    """

    # Read train or test set
    df = pd.read_parquet(os.path.join(root, f'{subset}.parquet'))
    df = df.rename(columns={
        'buildingblock1_smiles': 'block1',
        'buildingblock2_smiles': 'block2',
        'buildingblock3_smiles': 'block3',
        'molecule_smiles': 'smiles'})

    # Group by molecule -> get multiclass labels
    cols = ['block1', 'block2', 'block3', 'smiles']
    values = 'binds' if subset == 'train' else 'id'
    df = df.pivot(index=cols, columns='protein_name', values=values).reset_index()

    return df


def make_parquet(working: str, seed: int, **kwargs) -> None:
    """
    Make Dask DataFrame:

    * Read and shuffle dataframe.
    * Stack labels binding affinity [BRD4, HSA, sEH]. Nan mask = 2.
    * Validation split (at least one non-shared blocks).
    * Add train/validation/test indicator (0/1/2).
    * Replace [Dy] DNA-linker with [H]
    * Get ECFPs
    * Write to parquet file.

    Source of extra data: https://chemrxiv.org/engage/chemrxiv/article-details/6438943f08c86922ffeffe57
    Processed by: @chemdatafarmer @Hengck23
    """

    def validation_split(x, test: set):
        """
        Get train (0), or validation (1) indicators.
        Train subset: zero-intersection between blocks and "test" blocks.
        Validation subset: non-zero-intersection between blocks and "test" blocks.
        """

        blocks = set(x[col] for col in ['block1', 'block2', 'block3'])
        i = len(blocks.intersection(test))
        i = 0 if i == 0 else 1
        i = np.int8(i)

        return i

    def replace_linker(smiles):
        """
        Replace [Dy] linker with hydrogen.
        """
        smiles = smiles.replace('[Dy]', '[H]')
        smiles = Chem.CanonSmiles(smiles)
        return smiles

    # Iterate over subsets
    dataset = []
    for subset in ['test', 'extra', 'train']:

        # Read parquet/csv
        if subset in ['train', 'test']:
            df = read_parquet(subset=subset, **parameters)
        else:
            df = pd.read_csv(os.path.join(working, 'DNA_Labeled_Data.csv'), usecols=['new_structure', 'read_count'])
            df = df.rename(columns={'new_structure': 'smiles', 'read_count': 'binds'})

        # Stack binding affinity labels
        cols = ['BRD4', 'HSA', 'sEH']
        if subset == 'train':
            df['binds'] = np.stack([df[col].to_numpy() for col in cols], axis=-1, dtype=np.int8).tolist()
        elif subset == 'test':
            df["binds"] = np.tile(np.array([[2, 2, 2]], dtype=np.int8), reps=(df.shape[0], 1)).tolist()
        else:
            df['binds'] = df['binds'].mapply(lambda x: [2, 2, np.clip(x, a_min=0, a_max=1)])
        for col in cols:
            df = df.drop(columns=[col]) if col in df.columns else df

        # Validation split
        if subset == 'train':
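            # Sort so that the split depends only on the seed, not on set iteration order
            blocks = sorted(set(df['block1'].tolist()) | set(df['block2'].tolist()) | set(df['block3'].tolist()))
            _, val = train_test_split(blocks, test_size=0.03, random_state=seed)
            val = set(val)  # fast membership tests in validation_split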
            df['subset'] = df.mapply(lambda x: validation_split(x, test=val), axis=1)
        elif subset == 'test':
            df['subset'] = 2
        else:
            df['subset'] = 0    # Use extra data only for training
        cols = ['block1', 'block2', 'block3']
        for col in cols:
            df = df.drop(columns=[col]) if col in df.columns else df

        # Replace [Dy] DNA-linker with [H]
        df['smiles_no_linker'] = df['smiles'].mapply(lambda x: replace_linker(smiles=x))

        # Append dataframe to list
        dataset.append(df)

    # Concatenate -> Shuffle -> Convert to Dask DataFrame with 20 partitions
    # (older Dask versions require npartitions or chunksize in from_pandas)
    df = pd.concat(dataset)
    df = df.sample(frac=1.0, ignore_index=True, random_state=seed)
    df = dd.from_pandas(df, npartitions=20)

    # Write to parquet
    df.to_parquet(os.path.join(working, 'belka.parquet'), schema={
        'smiles': pa.string(),
        'binds': pa.list_(pa.int8(), 3),
        'subset': pa.int8(),
        'smiles_no_linker': pa.string()})

    return None


def make_dataset(working: str, **kwargs) -> None:
    """
    Tokenize SMILES, compute ECFP fingerprints and save the dataset (tf.data.Dataset.save, GZIP-compressed).
    """

    def generator() -> dict:
        for row in df.itertuples(index=False, name='Row'):
            yield {
                'smiles': row.smiles,
                'smiles_no_linker': row.smiles_no_linker,
                'binds': row.binds,
                'subset': row.subset}

    def serialize_smiles(x) -> dict:
        """
        Serialize smiles to string.
        """

        x['smiles'] = tf.io.serialize_tensor(x['smiles'])
        return x

    def get_ecfp(x) -> dict:
        """
        Compute ECFP from "smiles_no_linker".
        """

        x['ecfp'] = transformer(x['smiles_no_linker'])
        x.pop('smiles_no_linker')
        return x

    # Read dataset
    df = dd.read_parquet(os.path.join(working, 'belka.parquet'))
    df = df.compute()

    # Tokenize SMILES
    df['smiles'] = df['smiles'].mapply(lambda x: atomInSmiles.smiles_tokenizer(x))

    # Build the tf.data pipeline and save it (tf.data.Dataset.save snapshot format, not raw TFRecord files)
    auto = tf.data.AUTOTUNE
    transformer = FPGenerator()
    ds = tf.data.Dataset.from_generator(
        generator=lambda : generator(),
        output_signature={
            'smiles': tf.TensorSpec(shape=(None,), dtype=tf.string),
            'smiles_no_linker': tf.TensorSpec(shape=(), dtype=tf.string),
            'binds': tf.TensorSpec(shape=(3,), dtype=tf.int8),
            'subset': tf.TensorSpec(shape=(), dtype=tf.int8)})
    ds = ds.map(lambda x: serialize_smiles(x))
    ds = ds.batch(batch_size=1024, num_parallel_calls=auto)
    ds = ds.map(lambda x: get_ecfp(x), num_parallel_calls=auto)
    ds = ds.unbatch()
    ds.save(os.path.join(working, 'belka.tfr'), compression='GZIP')
    return None


def get_vocab(working: str, **kwargs) -> None:
    """
    Get vocabulary for SMILES encoding.
    """

    # Read parquet
    df = dd.read_parquet(os.path.join(working, 'belka.parquet'))
    df = df.compute()

    # Tokenize SMILES -> Get list of unique tokens
    df['smiles'] = df['smiles'].mapply(lambda x: list(set(atomInSmiles.smiles_tokenizer(x))))
    vocab = np.unique(list(itertools.chain.from_iterable(df['smiles'].tolist()))).tolist()
    vocab = pd.DataFrame(data=vocab)
    vocab.to_csv(os.path.join(working, 'vocab.txt'), index=False, header=False)
    return None


def get_smiles_encoder(vocab: str, **kwargs) -> TextVectorization:
    """
    Get TextVectorization SMILES encoder.
    """

    tokenizer = TextVectorization(
        standardize=None,
        split=None,
        vocabulary=vocab)
    return tokenizer


def load_model(model: str, **kwargs) -> tf.keras.Model:
    """
    Load the model.
    """

    model = tf.keras.models.load_model(model, compile=True, custom_objects={
        'Encodings': Encodings,
        'Embeddings': Embeddings,
        'FeedForward': FeedForward,
        'SelfAttention': SelfAttention,
        'EncoderLayer': EncoderLayer,
        'MultiLabelLoss': MultiLabelLoss,
        'CategoricalLoss': CategoricalLoss,
        'BinaryLoss': BinaryLoss,
        'MaskedAUC': MaskedAUC})
    return model


def set_parameters(
        activation: str, batch_size: int, buffer_size: Union[int, float],
        depth: int,
        dropout_rate: float, epochs: int, epsilon: float, initial_epoch: int,
        masking_rate: float, max_length: int, mode: str, model: Union[str, None],
        model_name: str, num_heads: int, num_layers: int,
        patience: int, root: str, seed: int, steps_per_epoch: int, validation_steps: int, vocab: str, vocab_size: int,
        working: str) -> dict:
    """
    Set uniform parameters for the functions in the scope of the project.
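    :param mode: Choose from ['clf', 'fps', 'mlm'] (binding classification, SMILES-to-ECFP pretraining, masked-token pretraining).
    :param vocab_size: Set to N+2, where N is the number of tokens ([PAD] = 0, [MASK] = 1).
    :return: Dictionary of parameters.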
    """

    inputs = {
        'activation': activation,
        'batch_size': batch_size,
        'buffer_size': int(buffer_size),
        'depth': depth,
        'dropout_rate': dropout_rate,
        'epochs': epochs,
        'epsilon': epsilon,
        'initial_epoch': initial_epoch,
        'masking_rate': masking_rate,
        'max_length': max_length,
        'mode': mode,
        'model': model,
        'model_name': model_name,
        'num_heads': num_heads,
        'num_layers': num_layers,
        'patience': patience,
        'root': root,
        'seed': seed,
        'steps_per_epoch': steps_per_epoch,
        'validation_steps': validation_steps,
        'vocab': vocab,
        'vocab_size': vocab_size,
        'working': working}

    return inputs


# Initialize mapply (parallel pandas apply on all cores), then set the run parameters
mapply.init(n_workers=-1, progressbar=True)
parameters = set_parameters(
    root='leash-BELKA',
    working='working/',
    vocab='vocab.txt',
    model=None,
    mode='clf',
    model_name='belka',
    masking_rate=0.15,
    batch_size=2048, buffer_size=1e07,
    epochs=1000, initial_epoch=0, steps_per_epoch=10000, validation_steps=2000,
    max_length=128, vocab_size=43,
    depth=32, dropout_rate=0.1, num_heads=8, num_layers=4, activation='gelu',
    patience=20, epsilon=1e-07, seed=42)
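The label stacking in the listing above uses 2 as a "no measurement" value: every test row is [2, 2, 2], and the external data only carries an sEH label, so loss and metrics must skip those cells. The sketch below shows one way to mask a binary cross-entropy under that convention, assuming per-protein sigmoid probabilities. `masked_bce` is a hypothetical NumPy helper for illustration only, not the author's `BinaryLoss` or `MultiLabelLoss`, whose code is not part of this section.

```python
import numpy as np

MISSING = 2  # label value meaning "no measurement", as in the dataset code above


def masked_bce(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7) -> float:
    """Hypothetical helper: mean binary cross-entropy over labelled cells only."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)
    mask = y_true != MISSING                  # True where a real 0/1 label exists
    y = np.where(mask, y_true, 0.0)           # placeholder value for masked cells
    loss = -(y * np.log(y_pred) + (1.0 - y) * np.log(1.0 - y_pred))
    # Masked cells contribute zero; average over labelled cells only
    return float((loss * mask).sum() / max(mask.sum(), 1))


# Columns: [BRD4, HSA, sEH]; the second row mimics the external sEH-only data
y_true = np.array([[0, 1, 0],
                   [2, 2, 1]])
y_pred = np.array([[0.1, 0.8, 0.2],
                   [0.5, 0.5, 0.9]])
print(masked_bce(y_true, y_pred))
```

Predictions for the BRD4/HSA cells of the second row have no effect on the result, which is the property a masked loss (and a masked AUC) needs for the mixed BELKA + external training set.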

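The `masking_rate=0.15` parameter and the `[PAD] = 0, [MASK] = 1` convention in `set_parameters` correspond to the BERT-style masking described in the pretraining section (15% of tokens selected; of those, 80% masked, 10% replaced by a random token, 10% kept). The author's masking code is not shown here, so the following is a minimal NumPy sketch under those assumptions; `bert_mask` is a hypothetical name, and ids 2 to 42 are assumed to be the real tokens of a 43-entry vocabulary.

```python
import numpy as np

PAD_ID, MASK_ID = 0, 1   # reserved ids, per the set_parameters docstring
VOCAB_SIZE = 43          # N real tokens + 2 reserved ids


def bert_mask(tokens: np.ndarray, rate: float = 0.15, seed: int = 42):
    """Hypothetical helper: return (inputs, labels, selected) for masked-token pretraining."""
    rng = np.random.default_rng(seed)
    tokens = np.asarray(tokens)
    # Select about `rate` of the real (non-padding) tokens as prediction targets
    selected = (rng.random(tokens.shape) < rate) & (tokens != PAD_ID)
    roll = rng.random(tokens.shape)
    inputs = tokens.copy()
    # 80% of targets -> [MASK]
    inputs[selected & (roll < 0.8)] = MASK_ID
    # 10% of targets -> random real token (ids 2..VOCAB_SIZE-1)
    random_ids = rng.integers(2, VOCAB_SIZE, size=tokens.shape)
    swap = selected & (roll >= 0.8) & (roll < 0.9)
    inputs[swap] = random_ids[swap]
    # Remaining 10% of targets keep the original token
    labels = np.where(selected, tokens, PAD_ID)   # PAD_ID in labels = ignored by the loss
    return inputs, labels, selected


batch = np.random.default_rng(0).integers(2, VOCAB_SIZE, size=(4, 128))
batch[:, 100:] = PAD_ID   # simulate padding up to max_length=128
inputs, labels, selected = bert_mask(batch)
print(selected.sum() / (batch != PAD_ID).sum())   # fraction of real tokens selected, close to rate
```

Padding is never selected, and the labels keep the original id only at selected positions, so a loss that ignores id 0 trains on exactly the masked targets.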

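`get_vocab` collects the unique tokens of every SMILES, and `get_smiles_encoder` maps them to integer ids with 0 and 1 reserved, which is why `vocab_size` is N+2. The pure-Python sketch below illustrates that flow with tokens of the kind described earlier (atoms, digits, or bracketed content such as [C@@]). The regular expression is an assumed stand-in for `atomInSmiles.smiles_tokenizer` and will not reproduce its output exactly; `tokenize`, `build_vocab` and `encode` are hypothetical helpers.

```python
import re
from itertools import chain

# Assumed pattern: bracket atoms, two-letter halogens, two-digit ring closures, then any single character
TOKEN_RE = re.compile(r"\[[^\]]+\]|Br|Cl|%\d{2}|.")


def tokenize(smiles: str) -> list[str]:
    """Hypothetical character-level SMILES tokenizer."""
    return TOKEN_RE.findall(smiles)


def build_vocab(smiles_list: list[str]) -> dict[str, int]:
    """Map each unique token to an id, reserving 0 for [PAD] and 1 for [MASK]."""
    tokens = sorted(set(chain.from_iterable(tokenize(s) for s in smiles_list)))
    return {tok: i + 2 for i, tok in enumerate(tokens)}


def encode(smiles: str, vocab: dict[str, int], max_length: int = 128) -> list[int]:
    """Convert SMILES to ids, truncated and right-padded with [PAD] = 0."""
    ids = [vocab[t] for t in tokenize(smiles)][:max_length]
    return ids + [0] * (max_length - len(ids))


smiles = ["C[C@@H](N)c1ccccc1", "ClCC(=O)Br"]
vocab = build_vocab(smiles)
print(len(vocab) + 2)   # 13 (vocab_size = N + 2)
print(encode(smiles[1], vocab, max_length=12))
```

A near character-level scheme like this keeps the vocabulary tiny, consistent with the 43-token vocabulary and the small embedding depth (32) used in the configuration above.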
Copyright Created by DataER | Shanghai ICP Filing No. 2024052789-5 | Shanghai Public Security Filing No. 31010402336337