# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import math
from copy import deepcopy
from itertools import permutations
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers
from doctr.datasets import VOCABS
from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
from ...classification import vit_s
from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
from .base import _PARSeq, _PARSeqPostProcessor
__all__ = ["PARSeq", "parseq"]

# Default configuration of each PARSeq variant: per-channel mean/std (presumably used to normalize the
# inputs), input shape (H, W, C), vocabulary, and the URL of the pretrained weights.
default_cfgs: Dict[str, Dict[str, Any]] = {
    "parseq": dict(
        mean=(0.694, 0.695, 0.693),
        std=(0.299, 0.296, 0.301),
        input_shape=(32, 128, 3),
        vocab=VOCABS["french"],
        url="https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
    ),
}
class CharEmbedding(layers.Layer):
    """Character embedding layer whose output is scaled by sqrt(d_model)

    Args:
    ----
        vocab_size: size of the vocabulary (number of distinct token ids)
        d_model: dimension of the model (size of each embedding vector)
    """

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = layers.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        # Look up the token vectors, then multiply them by sqrt(d_model)
        scale = math.sqrt(self.d_model)
        embedded = self.embedding(x, **kwargs)
        return scale * embedded
class PARSeqDecoder(layers.Layer):
    """Implements decoder module of the PARSeq model

    A single pre-norm decoder layer, in this order:
    1. self-attention of the normalized queries over the normalized content;
    2. cross-attention over the visual memory;
    3. a position-wise feed-forward block.
    Each step is added back to the queries as a residual, and a final layer normalization is applied.

    Args:
    ----
        d_model: dimension of the model
        num_heads: number of attention heads
        ffd: dimension of the feed forward layer (the hidden size actually used is ``ffd * ffd_ratio``)
        ffd_ratio: depth multiplier for the feed forward layer
        dropout: dropout rate
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int = 12,
        ffd: int = 2048,
        ffd_ratio: int = 4,
        dropout: float = 0.1,
    ):
        super(PARSeqDecoder, self).__init__()
        self.attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
        self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
        # Feed-forward hidden size is ffd * ffd_ratio, with a GELU activation
        self.position_feed_forward = PositionwiseFeedForward(
            d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
        )

        # NOTE(review): attention_norm and cross_attention_norm are created but never used in `call`.
        # The cross-attention input is normalized with query_norm instead. They are presumably kept so that
        # the layer layout matches the pretrained weight file — confirm before removing/renaming anything here.
        self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
        self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
        self.query_norm = layers.LayerNormalization(epsilon=1e-5)
        self.content_norm = layers.LayerNormalization(epsilon=1e-5)
        self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
        self.output_norm = layers.LayerNormalization(epsilon=1e-5)

        self.attention_dropout = layers.Dropout(dropout)
        self.cross_attention_dropout = layers.Dropout(dropout)
        self.feed_forward_dropout = layers.Dropout(dropout)

    def call(
        self,
        target,
        content,
        memory,
        target_mask=None,
        **kwargs: Any,
    ):
        """Run the decoder layer.

        Args:
        ----
            target: query embeddings (presumably the position queries, (N, target_len, d_model) — TODO confirm)
            content: embedded context tokens used as keys/values of the self-attention
            memory: visual features used as keys/values of the cross-attention
            target_mask: optional attention mask forwarded to the self-attention
            **kwargs: forwarded to every sub-layer (e.g. ``training``)

        Returns:
        -------
            the updated queries after the final layer normalization
        """
        query_norm = self.query_norm(target, **kwargs)
        content_norm = self.content_norm(content, **kwargs)
        # Self-attention: normalized queries attend over the normalized content
        target = target + self.attention_dropout(
            self.attention(query_norm, content_norm, content_norm, mask=target_mask, **kwargs), **kwargs
        )
        # Cross-attention over the visual memory (input normalized with query_norm, see note in __init__)
        target = target + self.cross_attention_dropout(
            self.cross_attention(self.query_norm(target, **kwargs), memory, memory, **kwargs), **kwargs
        )
        # Position-wise feed-forward block with pre-normalization
        target = target + self.feed_forward_dropout(
            self.position_feed_forward(self.feed_forward_norm(target, **kwargs), **kwargs), **kwargs
        )
        return self.output_norm(target, **kwargs)
class PARSeq(_PARSeq, Model):
    """Implements a PARSeq architecture as described in `"Scene Text Recognition
    with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.
    Modified implementation based on the
    `official PyTorch implementation <https://github.com/baudm/parseq/tree/main>`_.

    Token ids used internally:
    - ``0 .. vocab_size - 1`` are the characters;
    - ``vocab_size`` is EOS;
    - ``vocab_size + 1`` is SOS;
    - ``vocab_size + 2`` is PAD.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        embedding_units: number of embedding units
        max_length: maximum word length handled by the model
        dropout_prob: dropout probability for the decoder
        dec_num_heads: number of attention heads in the decoder
        dec_ff_dim: dimension of the feed forward layer in the decoder
        dec_ffd_ratio: depth multiplier for the feed forward layer in the decoder
        input_shape: input shape of the image (not used by the constructor itself)
        exportable: onnx exportable returns only logits
        cfg: dictionary containing information about the model
    """

    _children_names: List[str] = ["feat_extractor", "postprocessor"]

    def __init__(
        self,
        feature_extractor,
        vocab: str,
        embedding_units: int,
        max_length: int = 32,  # different from paper
        dropout_prob: float = 0.1,
        dec_num_heads: int = 12,
        dec_ff_dim: int = 384,  # we use it from the original implementation instead of 2048
        dec_ffd_ratio: int = 4,
        input_shape: Tuple[int, int, int] = (32, 128, 3),
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = max_length
        self.vocab_size = len(vocab)
        # NumPy generator used to sample training permutations
        self.rng = np.random.default_rng()

        self.feat_extractor = feature_extractor
        self.decoder = PARSeqDecoder(embedding_units, dec_num_heads, dec_ff_dim, dec_ffd_ratio, dropout_prob)
        self.embed = CharEmbedding(self.vocab_size + 3, embedding_units)  # +3 for SOS, EOS, PAD
        self.head = layers.Dense(self.vocab_size + 1, name="head")  # +1 for EOS
        # Learned position queries: one per output position (max_length characters + EOS)
        self.pos_queries = self.add_weight(
            shape=(1, self.max_length + 1, embedding_units),
            initializer="zeros",
            trainable=True,
            name="positions",
        )
        self.dropout = layers.Dropout(dropout_prob)

        self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

    def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
        """Generate the permutations of the target sequence used for permuted language modeling.

        Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
        with small modifications. Runs eagerly: the longest length is converted to a Python int.

        Args:
        ----
            seqlen: lengths of the target sequences in the batch

        Returns:
        -------
            int32 tensor of shape (num_perms, max_num_chars + 2).
            Every row starts with the SOS position 0, and rows come in (permutation, complement) pairs.
            The second row is special: the reverse order with EOS moved right after SOS.
        """
        max_num_chars = int(tf.reduce_max(seqlen))  # get longest sequence length in batch
        # The canonical left-to-right order is always included
        perms = [tf.range(max_num_chars, dtype=tf.int32)]
        # Every permutation is paired with its flipped complement below, so only half of them are needed
        max_perms = math.factorial(max_num_chars) // 2
        num_gen_perms = min(3, max_perms)
        if max_num_chars < 5:
            # Pool of permutations to sample from. We only need the first half (if complementary option is selected)
            # Special handling for max_num_chars == 4 which correctly divides the pool into the flipped halves
            if max_num_chars == 4:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool_candidates = list(permutations(range(max_num_chars), max_num_chars))
            perm_pool = tf.convert_to_tensor([perm_pool_candidates[i] for i in selector])
            # If the forward permutation is always selected, no need to add it to the pool for sampling
            perm_pool = perm_pool[1:]
            final_perms = tf.stack(perms)
            if len(perm_pool):
                # Sample distinct pool rows without replacement and gather exactly those rows.
                # Slicing between the first two sampled indices would select the wrong (possibly empty) rows.
                i = self.rng.choice(len(perm_pool), size=num_gen_perms - len(final_perms), replace=False)
                final_perms = tf.concat([final_perms, tf.gather(perm_pool, i)], axis=0)
        else:
            # Too many permutations to enumerate: draw random shuffles instead
            perms.extend([
                tf.random.shuffle(tf.range(max_num_chars, dtype=tf.int32)) for _ in range(num_gen_perms - len(perms))
            ])
            final_perms = tf.stack(perms)

        # Interleave every permutation with its reversed complement
        comp = tf.reverse(final_perms, axis=[-1])
        final_perms = tf.stack([final_perms, comp])
        final_perms = tf.transpose(final_perms, perm=[1, 0, 2])
        final_perms = tf.reshape(final_perms, shape=(-1, max_num_chars))

        # Add the SOS position (0) in front and the EOS position (max_num_chars + 1) at the end
        sos_idx = tf.zeros([tf.shape(final_perms)[0], 1], dtype=tf.int32)
        eos_idx = tf.fill([tf.shape(final_perms)[0], 1], max_num_chars + 1)
        combined = tf.concat([sos_idx, final_perms + 1, eos_idx], axis=1)
        combined = tf.cast(combined, dtype=tf.int32)
        if tf.shape(combined)[0] > 1:
            # Special handling of the reverse direction: reverse context for the characters
            # and EOS placed right after SOS
            combined = tf.tensor_scatter_nd_update(
                combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
            )
        return combined

    def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Generate source and target masks for the decoder attention from one permutation.

        Position ``q`` may attend to position ``k`` only if ``k`` does not come after ``q`` in the
        permutation order. The target mask additionally hides every position from itself.

        Args:
        ----
            permutation: 1D permutation of ``0 .. sz - 1`` starting with the SOS position 0, as produced by
                ``generate_permutations`` (must be concrete / eager)

        Returns:
        -------
            boolean ``(source_mask, target_mask)``, both of shape (sz - 1, sz - 1); True means "may attend"
        """
        perm = np.asarray(permutation)
        sz = perm.shape[0]
        # rank[p] is the position at which index p appears in the permutation
        rank = np.empty(sz, dtype=np.int64)
        rank[perm] = np.arange(sz)
        # mask[q, k] is False exactly when k comes after q in the permutation order
        mask = rank[np.newaxis, :] <= rank[:, np.newaxis]
        source_mask = mask[:-1, :-1].copy()
        # The target mask additionally forbids every position from attending to itself
        np.fill_diagonal(mask, False)
        target_mask = mask[1:, :-1]
        return tf.convert_to_tensor(source_mask, dtype=tf.bool), tf.convert_to_tensor(target_mask, dtype=tf.bool)

    def decode(
        self,
        target: tf.Tensor,
        memory: tf.Tensor,
        target_mask: Optional[tf.Tensor] = None,
        target_query: Optional[tf.Tensor] = None,
        **kwargs: Any,
    ) -> tf.Tensor:
        """Run the decoder on a batch of token ids.

        Args:
        ----
            target: token ids starting with SOS, of shape (N, sequence_length)
            memory: visual features from the feature extractor
            target_mask: optional attention mask for the decoder self-attention
            target_query: optional position queries; defaults to the first ``sequence_length`` learned queries
            **kwargs: forwarded to the sub-layers (e.g. ``training``)

        Returns:
        -------
            decoder output for each query position
        """
        batch_size, sequence_length = target.shape
        # apply positional information to the target sequence excluding the SOS token
        null_ctx = self.embed(target[:, :1], **kwargs)
        content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
        content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
        if target_query is None:
            target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
        target_query = self.dropout(target_query, **kwargs)
        return self.decoder(target_query, content, memory, target_mask, **kwargs)

    @tf.function
    def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
        """Generate predictions for the given features.

        Tokens are decoded one at a time, followed by one refinement pass over the whole sequence.

        Args:
        ----
            features: visual features (without the cls token)
            max_len: optional maximum number of characters to decode (capped at ``self.max_length``)
            **kwargs: forwarded to the sub-layers (e.g. ``training``)

        Returns:
        -------
            logits of shape (N, min(max_len, max_length) + 1, vocab_size + 1)
        """
        max_length = max_len if max_len is not None else self.max_length
        max_length = min(max_length, self.max_length) + 1
        b = tf.shape(features)[0]
        # Padding symbol + SOS at the beginning
        ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
        start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
        ys = tf.concat([start_vector, ys], axis=-1)
        pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
        query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)

        pos_logits = []
        for i in range(max_length):
            # Decode one token at a time without providing information about the future tokens
            tgt_out = self.decode(
                ys[:, : i + 1],
                features,
                query_mask[i : i + 1, : i + 1],
                target_query=pos_queries[:, i : i + 1],
                **kwargs,
            )
            pos_prob = self.head(tgt_out)
            pos_logits.append(pos_prob)

            if i + 1 < max_length:
                # update ys with the next token
                i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
                indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
                ys = tf.tensor_scatter_nd_update(
                    ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
                )

                # Stop decoding if all sequences have reached the EOS token
                # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
                # NOTE(review): as written, this is True only if some row consists solely of EOS tokens. That
                # cannot happen because ys[:, 0] is SOS, so decoding never stops early. "Every row contains an
                # EOS" (reduce_all over reduce_any) looks intended. It is kept unchanged because changing a
                # tensor-dependent `break` inside this tf.function needs separate validation.
                if (
                    not self.exportable
                    and max_len is None
                    and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
                ):
                    break

        logits = tf.concat(pos_logits, axis=1)  # (N, max_length, vocab_size + 1)

        # One refine iteration
        # Update query mask: query i may now attend to every position except position i + 1
        diag_matrix = tf.eye(max_length)
        diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
        query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)

        sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
        ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
        # Create padding mask for the refined target input: positions up to and including the first EOS token
        # stay visible (True) and everything behind it is masked (False). Rows without any EOS stay visible.
        # (N, 1, 1, max_length)
        ys_len = ys.shape[-1]
        eos_hits = tf.equal(ys, self.vocab_size)
        first_eos_indices = tf.argmax(tf.cast(eos_hits, tf.float32), axis=1, output_type=tf.int32)
        # argmax returns 0 for a row without EOS, which would hide everything but SOS: use the full length instead
        first_eos_indices = tf.where(
            tf.reduce_any(eos_hits, axis=1), first_eos_indices, tf.constant(ys_len, dtype=tf.int32)
        )
        mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys_len, dtype=tf.float32)
        target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)
        mask = tf.math.logical_and(target_pad_mask, query_mask[:, :ys_len])
        logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)

        return logits  # (N, max_length, vocab_size + 1)

    def call(
        self,
        x: tf.Tensor,
        target: Optional[List[str]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the model on a batch of images.

        Args:
        ----
            x: batch of images
            target: optional list of ground-truth words; required when ``training=True`` is passed
            return_model_output: if True, add the raw logits under ``"out_map"``
            return_preds: if True, also add the decoded predictions when ``target`` is given
            **kwargs: forwarded to the sub-layers (e.g. ``training``)

        Returns:
        -------
            dictionary with only ``"logits"`` when exportable. Otherwise it may contain:
            - ``"out_map"``: the raw logits;
            - ``"preds"``: a list of (word, confidence) pairs;
            - ``"loss"``: present when ``target`` is given.

        Raises:
        ------
            ValueError: if ``training=True`` and no target is provided
        """
        features = self.feat_extractor(x, **kwargs)  # (batch_size, patches_seqlen, d_model)
        # remove cls token
        features = features[:, 1:, :]

        if kwargs.get("training", False) and target is None:
            raise ValueError("Need to provide labels during training")

        if target is not None:
            gt, seq_len = self.build_target(target)
            seq_len = tf.cast(seq_len, tf.int32)
            gt = gt[:, : int(tf.reduce_max(seq_len)) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)

            if kwargs.get("training", False):
                # Generate permutations of the target sequences
                tgt_perms = self.generate_permutations(seq_len)

                gt_in = gt[:, :-1]  # remove EOS token from longest target sequence
                gt_out = gt[:, 1:]  # remove SOS token

                # Create padding mask for target input
                # [True, True, True, ..., False, False, False] -> False is masked
                padding_mask = tf.math.logical_and(
                    tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
                )
                padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :]  # (N, 1, 1, seq_len)

                # Loss is the mean over permutations, weighted by the number of non-PAD targets
                loss = tf.constant(0.0)
                loss_numel = tf.constant(0.0)
                n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
                for i, perm in enumerate(tgt_perms):
                    _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
                    # combine both masks to (N, 1, seq_len, seq_len)
                    mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))

                    logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
                    logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
                    targets_flat = tf.reshape(gt_out, (-1,))
                    mask = tf.not_equal(targets_flat, self.vocab_size + 2)
                    loss += n * tf.reduce_mean(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
                        )
                    )
                    loss_numel += n

                    # After the second iteration (i.e. done with canonical and reverse orderings),
                    # remove the [EOS] tokens for the succeeding perms
                    if i == 1:
                        gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
                        n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))

                loss /= loss_numel

            else:
                gt = gt[:, 1:]  # remove SOS token
                max_len = gt.shape[1] - 1  # exclude EOS token
                logits = self.decode_autoregressive(features, max_len, **kwargs)
                logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
                targets_flat = tf.reshape(gt, (-1,))
                mask = tf.not_equal(targets_flat, self.vocab_size + 2)
                loss = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
                    )
                )
        else:
            logits = self.decode_autoregressive(features, **kwargs)

        logits = _bf16_to_float32(logits)

        out: Dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output:
            out["out_map"] = logits

        if target is None or return_preds:
            # Decode the logits into (word, confidence) pairs
            out["preds"] = self.postprocessor(logits)

        if target is not None:
            out["loss"] = loss

        return out
class PARSeqPostProcessor(_PARSeqPostProcessor):
    """Post processor for PARSeq architecture

    Turns logits into (word, confidence) pairs by greedy (argmax) decoding up to the first EOS.

    Args:
    ----
        vocab: string containing the ordered sequence of supported characters
    """

    def __call__(
        self,
        logits: tf.Tensor,
    ) -> List[Tuple[str, float]]:
        # Greedy decoding: most likely class index and its softmax probability at every position
        best_idxs = tf.cast(tf.math.argmax(logits, axis=2), dtype="int32")
        best_probs = tf.math.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

        # Map indices to their string tokens, join each row, then keep only the text before the first "<eos>"
        lookup_table = tf.constant(self._embedding, dtype=tf.string)
        joined = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(lookup_table, best_idxs), axis=-1)
        pieces = tf.strings.split(joined, "<eos>")
        first_pieces = tf.sparse.to_dense(pieces.to_sparse(), default_value="not valid")[:, 0]
        words = [raw.decode() for raw in first_pieces.numpy().tolist()]

        # Word confidence = mean clipped probability over the first len(word) positions; empty words get 0.0
        best_probs_np = best_probs.numpy()
        confidences = []
        for row, word in enumerate(words):
            if word:
                confidences.append(best_probs_np[row, : len(word)].clip(0, 1).mean().item())
            else:
                confidences.append(0.0)

        return list(zip(words, confidences))
def _parseq(
    arch: str,
    pretrained: bool,
    backbone_fn,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> PARSeq:
    """Build a PARSeq model for the given architecture name.

    Args:
    ----
        arch: key of the architecture in ``default_cfgs``
        pretrained: whether to load the pretrained weights from the config URL
        backbone_fn: callable building the feature extractor
        input_shape: optional override of the configured input shape
        **kwargs: extra arguments; ``patch_size`` and ``pretrained_backbone`` are consumed here,
            everything else is forwarded to ``PARSeq``

    Returns:
    -------
        the built PARSeq model
    """
    # Work on a copy so the module-level defaults are never mutated
    cfg = deepcopy(default_cfgs[arch])
    cfg["input_shape"] = input_shape or cfg["input_shape"]
    cfg["vocab"] = kwargs.get("vocab", cfg["vocab"])
    kwargs["vocab"] = cfg["vocab"]

    # Options that only concern the backbone and must not reach the PARSeq constructor
    patch_size = kwargs.pop("patch_size", (4, 8))
    kwargs.pop("pretrained_backbone", None)

    feat_extractor = backbone_fn(
        # NOTE: we don't use a pretrained backbone for non-rectangular patches to avoid the pos embed mismatch
        pretrained=False,
        input_shape=cfg["input_shape"],
        patch_size=patch_size,
        include_top=False,
    )

    model = PARSeq(feat_extractor, cfg=cfg, **kwargs)
    _build_model(model)

    if pretrained:
        # A vocab that differs from the pretrained one => skip the mismatching layers for fine tuning
        vocab_differs = kwargs["vocab"] != default_cfgs[arch]["vocab"]
        load_pretrained_params(model, default_cfgs[arch]["url"], skip_mismatch=vocab_differs)

    return model
def parseq(pretrained: bool = False, **kwargs: Any) -> PARSeq:
    """PARSeq architecture from
    `"Scene Text Recognition with Permuted Autoregressive Sequence Models" <https://arxiv.org/pdf/2207.06966>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import parseq
    >>> model = parseq(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the PARSeq architecture. ``patch_size`` (default ``(4, 8)``) may be
            given to override the patch size of the ViT backbone.

    Returns:
    -------
        text recognition architecture
    """
    # Default backbone patch size; a caller-supplied value is kept instead of clashing with this default
    kwargs.setdefault("patch_size", (4, 8))
    return _parseq(
        "parseq",
        pretrained,
        vit_s,
        embedding_units=384,
        **kwargs,
    )