A Practical Future of Developing AI for Medical Devices
Bridging knowledge gaps and revealing processes when developing robust and long-term applications.
This article aims to bridge two knowledge gaps. The first is the application of language models to medical devices within the regulatory frameworks laid out by bodies like the FDA. The second is the layout and function of these models themselves, with the goal of demystifying artificial intelligence (AI). The second section of this article is meant to pull back the veil and reveal processes that are understandable, which is the best place to start when developing robust and long-term applications.
Language models and their derivatives have been captivating imaginations across every industry. Colloquially referred to as AI, these models have become so prominent that it’s difficult to find anyone without a take on them or their potential applications.
As a Machine Learning (ML)/AI-focused software engineer at a medical device development company, my job is to understand how language models work and how to make the technology dance on our strings. I work to minimize the risks that using language models in medical applications poses to both patients and clients. Many clients are intrigued by the possibilities and curious about how they might be able to incorporate this technology into their devices.
Language models in medical device development
Regulatory bodies take a risk-based approach towards approving medical devices. Language models make them nervous on two fronts: the non-determinism introduced by sampling, and the general risk of incorrect outputs given inputs from a domain as broad and untestable as natural language. If you plan to incorporate AI into a medical device as a component with any risk of affecting a user’s wellbeing, you’ll find yourself very constrained.
To train a model that gives long, coherent output while only using the most probable next token (greedy sampling), you’ll need a large model, lots of training data (Meta’s current standard is 15T tokens), and of course lots of very expensive hardware.
Training a large language model (LLM) from scratch is prohibitively expensive unless you’re a giant tech incumbent or funded by one. Instead, your likely best bet is to use a model trained by a third party. These third parties range in scope and grandeur from Meta and Microsoft to literally anybody on the internet.
These larger, generalized models will likely still need additional fine-tuning for your own applications. Thankfully, fine-tuning can be significantly cheaper and easier than full training; however, you’ll still need text data and hardware, either rented or owned.
Several large third parties offer fully functional and incredibly large LLMs. These services sometimes also promise to enable fine-tuning of their models based on smaller sets of examples than would be necessary for smaller models, or without any data at all.
Use of these services comes with the chance that the model you are using may change, replacing your verified outputs with potentially brand-new ones. Additionally, transmitting the user’s text may run afoul of privacy requirements (whether driven by regulatory or business pressures) or of the device requirements themselves, and at the very least it significantly increases the cyberattack surface. Finally, resting your entire device or company on an uncontrolled third-party web service poses an inherent, disastrous business risk, as well as a risk to the function of the device.
So, what can you do?
While smaller models might be cute-but-useless for generative language modelling, they are quite powerful for rapid classification of natural language. Models on the scale of 150M to 500M parameters can perform quite well, and significantly more quickly than their larger counterparts.
This use involves replacing the output head (which, as the architecture walkthrough below explains, only gives an output corresponding to the likelihood of each token being generated next) with one that converts the fully deterministic output of the final decoder layer into a new output that can correspond to anything. This can be a string of numbers that represents the meaning of the text, or it could be a single probability value corresponding to something like the “amount of hope” represented in the text: anything, as long as you have the data to train it for that task.
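A minimal sketch of this head swap, assuming PyTorch. `TinyCausalBackbone` is a hypothetical stand-in for a pretrained decoder stack (here a single `nn.TransformerEncoderLayer` with a causal mask), and `SequenceScorer` is a hypothetical wrapper that replaces the vocabulary-sized output head with a one-output linear layer and a sigmoid, yielding one probability per input text. Reading the hidden state at the last position assumes unpadded inputs.

```python
import torch
import torch.nn as nn


class TinyCausalBackbone(nn.Module):
    """Hypothetical stand-in for a small pretrained decoder stack: token ids -> hidden states."""
    def __init__(self, vocab_size=1000, n_embd=64, n_head=4):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.block = nn.TransformerEncoderLayer(
            n_embd, n_head, dim_feedforward=4 * n_embd, batch_first=True
        )

    def forward(self, idx):
        T = idx.size(1)
        # True above the diagonal = a position may not attend to later (future) positions
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        return self.block(self.wte(idx), src_mask=causal)  # (B, T, n_embd)


class SequenceScorer(nn.Module):
    """Replaces the vocabulary-sized LM head with a single-output scoring head."""
    def __init__(self, backbone, n_embd):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(n_embd, 1)  # instead of nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        hidden = self.backbone(idx)  # (B, T, n_embd)
        last = hidden[:, -1, :]      # causal model: the last position has seen the whole input
        return torch.sigmoid(self.head(last)).squeeze(-1)  # (B,) scores in (0, 1)


torch.manual_seed(0)
model = SequenceScorer(TinyCausalBackbone(), n_embd=64).eval()
idx = torch.randint(0, 1000, (2, 10))  # two made-up token sequences
with torch.no_grad():
    scores = model(idx)
```

Training such a head would typically pair it with a binary loss such as `nn.BCELoss` on labeled (or synthetic) examples; whether to freeze the backbone and train only the head, or fine-tune everything, is a design choice that depends on how much data you have.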
Also, this small scale allows you to use a larger language model to generate synthetic training data, which can turn a handful of examples into a much larger dataset. While model collapse is a real problem, your 150M parameter model will not be significantly poisoned by outputs produced from a model 50 (or over 1010) times larger.
Development of much smaller-scale clinician-assisting classification systems is nothing new. In fact, a majority of FDA-approved ML/AI-enabled medical devices are components of imaging systems intended to process images and make non-binding predictions to simplify workflows — working together with a human in the loop.
Implications for medical device development
Advances in language modeling have spawned a very lucrative space in the tech industry; however, adoption in the medical device industry has lagged, and should lag, significantly. This is driven by the models’ architecture and the compromises required for their usage, though some portion is also driven by the desire to put them to the wrong use.
Generative AI can be flashy, but for a more skeptical glimpse at the future of our current AI modeling methods, we can look to the much more mature field of computer vision. Regulatory bodies clearly see the value of ML/AI devices as methods to improve the speed and standard of care with experts driving the decision making.
If you have an interest in applying this technology yourself, read on. The next section breaks down the architecture of transformer-based language models, including both conceptual and code examples, along with a link to a fully functional harness for foundational training, so you can experiment yourself.
Modeling language with transformers
"To deal with hyper-planes in a 14-dimensional space, visualize a 3D space and say 14 to yourself very loudly. Everyone does it." – Geoffrey Hinton, A geometrical view of perceptrons.
At a very high level, language models can be thought of as taking a walk through a linguistic space. The inputs are a series of points that correspond to each input token, and the output corresponds to a new point, or mix of token positions, that could continue the walk.
Transformers, which form the backbone of the most recent trend in language modeling, are principally signal prediction models. They excel at general forecasting, and by representing language as a series of points through a space, we’re really turning language into another signal for the model to forecast, independent of what the language itself actually is.
Full implementation
You can find a fully working, generalized language model training system with expanded implementations, end-to-end training data pipelines, and extensive documentation.
Its primary focus is on the small language model regime, and by default it will train and save a 220M-parameter language model based on Meta's Llama architecture. Default training requires approximately 24 GB of VRAM (one RTX 4090; it’s still consumer grade, albeit stretching the one-GPU claim of this article), though this can be adjusted to fit your own hardware by modifying the batch size and gradient accumulation values.
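The batch-size/gradient-accumulation trade-off can be sketched in a few lines, assuming a PyTorch training loop with a mean-reduced loss. Gradients from several smaller micro-batches are summed into `.grad` before a single optimizer step, so peak activation memory scales with the micro-batch rather than the effective batch. The toy `nn.Linear` model and the helper names are hypothetical; the harness’s own configuration names may differ.

```python
import torch
import torch.nn as nn

loss_fn = nn.MSELoss()  # mean-reduced, so each micro-batch loss is an average


def full_batch_grad(model, x, y):
    model.zero_grad()
    loss_fn(model(x), y).backward()
    return model.weight.grad.clone()


def accumulated_grad(model, x, y, micro_batch_size):
    """Same gradient as full_batch_grad, computed one micro-batch at a time."""
    model.zero_grad()
    n_micro = x.size(0) // micro_batch_size  # assumes the batch divides evenly
    for i in range(n_micro):
        sl = slice(i * micro_batch_size, (i + 1) * micro_batch_size)
        # Divide by the number of micro-batches so the summed gradients equal the mean
        loss = loss_fn(model(x[sl]), y[sl]) / n_micro
        loss.backward()  # .grad accumulates across calls until zero_grad()
    # In a real loop, optimizer.step() and optimizer.zero_grad() would follow here
    return model.weight.grad.clone()


torch.manual_seed(0)
model = nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)
g_full = full_batch_grad(model, x, y)
g_accum = accumulated_grad(model, x, y, micro_batch_size=4)
```

Halving the micro-batch size while doubling the number of accumulation steps keeps the effective batch (and, up to floating-point rounding, the gradient) the same, at the cost of more forward/backward passes per optimizer step.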
Running the model from a saved checkpoint requires significantly less horsepower and can feasibly be done purely on a CPU, with <1 GB of RAM.
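A back-of-the-envelope check of that figure, assuming only the weights are counted (activations, the tokenizer, and runtime overhead add more) and 1 GB = 10^9 bytes: 220 × 10^6 parameters × 4 bytes each in float32 is 0.88 × 10^9 bytes. `weight_memory_gb` and `count_params` are hypothetical helpers.

```python
import torch.nn as nn

BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}


def weight_memory_gb(n_params, dtype="float32"):
    """Approximate memory needed just to hold the weights, in GB (10**9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


def count_params(model: nn.Module) -> int:
    # parameters() yields each shared (e.g. weight-tied) tensor only once
    return sum(p.numel() for p in model.parameters())


print(weight_memory_gb(220e6))              # 220M params in float32: about 0.88 GB
print(weight_memory_gb(220e6, "bfloat16"))  # same weights in bfloat16: about 0.44 GB
```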
Tokens: Language model hieroglyphics
Language families across the world can have radically different representations of meaning. Many languages construct words — the basic component of language that contains meaning — out of small individual symbols. Other language groups have different approaches, and while the quanta may vary, it all comes out to the same thing: communication.
Language models have their own types of hieroglyphs, called tokens. When representing English, these are usually common words, or chunks of letters, numbers, and spaces. Language models do not learn English; instead, they learn how to walk through a linguistic space by following these tokens.
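To make this concrete, here is a toy greedy longest-match tokenizer over a small, hypothetical vocabulary (`VOCAB`, `tokenize`, and `encode` are illustrative names). It only illustrates text becoming token ids: Llama-family models use a learned byte-pair-encoding (BPE) tokenizer with a much larger vocabulary, not this algorithm.

```python
# Hypothetical toy vocabulary; leading spaces are part of some tokens, as in GPT-style BPE vocabularies
VOCAB = ["the", " cat", " sat", " on", " mat", "s", " ", "a", "c", "e", "h", "m", "n", "o", "t"]
TOKEN_TO_ID = {tok: i for i, tok in enumerate(VOCAB)}
MAX_TOKEN_LEN = max(len(tok) for tok in VOCAB)


def tokenize(text):
    """Split text by repeatedly taking the longest vocabulary entry that matches."""
    tokens, i = [], 0
    while i < len(text):
        # Try the longest candidate first, shrinking until a vocabulary entry matches
        for j in range(min(len(text), i + MAX_TOKEN_LEN), i, -1):
            if text[i:j] in TOKEN_TO_ID:
                tokens.append(text[i:j])
                i = j
                break
        else:
            raise ValueError(f"no token covers {text[i]!r}")
    return tokens


def encode(text):
    """Map text to the integer ids the model actually sees."""
    return [TOKEN_TO_ID[tok] for tok in tokenize(text)]


print(tokenize("the cats sat"))  # ['the', ' cat', 's', ' sat']
print(encode("the cats sat"))    # [0, 1, 5, 2]
```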
Tokenization is principally an intermediary step: the token form doesn’t persist very far into the model itself, which learns its own representation (a point in a 1000+-dimensional space) for each token. But the first and last steps when passing a prompt through a language model involve mapping from and to these tokens, which makes the inputs and outputs human-interpretable.
Modeling language as a signal
Over the course of training, the model’s mapping between token and point is learned statistically from the data. The model also learns how paths move through this latent meaning, or semantic, space, and how to nudge the last position it saw into the right spot for the next token, based on the path the words have taken to get there.
Every aspect of a modern language model is oriented towards either translating a token into, or out of, this semantic space, or towards accumulating small nudges in this space to determine the next step on the walk.
The architecture
The components we’ll walk through here make up the majority of modern language (and multimodal) models, and the architecture of these models, i.e., how we stack these components together, is pretty consistent between models of wildly different sizes developed by wildly different companies.
Embeddings
Embeddings translate a discrete token, corresponding to a word or sub-word, into a position in a multi-dimensional space. In 3D space, points can be similar along up to three different dimensions; the space into which tokens are placed contains many more dimensions, which correspond to many different potential relations between groups of tokens.
To force the model’s output to occupy the same token space, try tying the translation in this layer with a detranslation in the output layer. While this step isn’t technically necessary, it means the model only has to learn one position for every token.
When we pass a tokenized sentence into this layer, we end up with a sequence of positions, a trajectory in the model’s semantic space — the path that the model is trying to continue. How well the model can learn this path, using both the information learned in its training data and the previous information provided as the input trajectory, determines how well the model performs at modelling language.
Implementation
In PyTorch, embeddings are a native module. This simplifies our implementation greatly, as we only need to add our own dropout regularization, which randomly zaps parts of the layer’s output during training. This forces a degree of robustness in the embeddings, since the model still needs to do a good job even with incomplete information.
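```python
import torch.nn as nn


class WordTokenEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.n_embed = config.n_embd
        # Lookup table: one learned n_embd-dimensional point per token id
        self.wte = nn.Embedding(self.vocab_size, self.n_embed)
        # Randomly zeroes parts of the output during training
        self.drop = nn.Dropout(config.dropout)

    def forward(self, idx):
        tok_emb = self.wte(idx)  # (B, T) token ids -> (B, T, n_embd) positions
        return self.drop(tok_emb)
```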
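A short usage sketch with hypothetical sizes (a 1,000-token vocabulary and 64 dimensions), using the same two operations as `WordTokenEmbeddings`: a tokenized sentence of shape `(batch, sequence)` becomes a trajectory of shape `(batch, sequence, n_embd)`, and a repeated token lands on exactly the same point. Dropout is only active in training mode, so the module is put in `eval()` here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 1000, 64  # hypothetical sizes
embed = nn.Sequential(nn.Embedding(vocab_size, n_embd), nn.Dropout(0.1)).eval()

idx = torch.tensor([[5, 42, 7, 42]])  # one "sentence" of four token ids; id 42 appears twice
trajectory = embed(idx)
print(trajectory.shape)  # torch.Size([1, 4, 64])
```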
Decoding layers
These layers make up the midsection of the language model, stacked end-to-end — each taking the output of the previous and making its own small adjustment to it. These nudges are accumulated as the data passes through each of the decoders, and finally the result is normalized before being compared to the model’s vocabulary.
Each decoder is made up of two sub-components that each produce a nudge in the data.
Implementation
Each decoder is a straightforward arrangement of more complex components: an attention layer and a multi-layer perceptron (MLP).
First, the attention layer calculates its nudge, which is added on top of the input signal. The result is then passed through the feed-forward MLP, which produces its own adjustment. This style of model, where adjustments are calculated and then added on top of the inputs, is called residual and has proven to be incredibly powerful.
```python
import torch.nn as nn


class Decoder(nn.Module):
    """
    Simplified standard transformers decoder structure (sans normalization)
    Forward involves attention residual followed by feed forward residual to be
    passed to next block in transformer
    """
    def __init__(self, config):
        super().__init__()
        self.attn = MHASelfAttention(config)  # defined below
        self.ff = MLP(config)                 # defined below

    def forward(self, x):
        x = x + self.attn(x)  # attention nudge added to the residual stream
        x = x + self.ff(x)    # feed-forward nudge added on top
        return x
```
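The listing above leaves normalization out. As a hedged sketch of where it typically goes in Llama-style models, which apply RMSNorm before each sub-layer (“pre-norm”) and once more after the final decoder, here is a block with a hand-written `RMSNorm`. The `nn.Linear` sub-layers are hypothetical placeholders for `MHASelfAttention` and `MLP`; zeroing them exposes the residual path, since a block whose nudges are zero passes its input through untouched.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-dimension gain."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Rescale each position's vector to unit root-mean-square, then apply the gain
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight


class PreNormDecoder(nn.Module):
    """Residual decoder block that normalizes the input to each sub-layer."""
    def __init__(self, attn, ff, dim):
        super().__init__()
        self.attn, self.ff = attn, ff
        self.attn_norm, self.ff_norm = RMSNorm(dim), RMSNorm(dim)

    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))  # nudge from the attention sub-layer
        x = x + self.ff(self.ff_norm(x))      # nudge from the feed-forward sub-layer
        return x


torch.manual_seed(0)
dim = 16
x = torch.randn(2, 5, dim)

# Placeholder sub-layers zeroed out: both nudges vanish and the block is an identity map
zero_attn, zero_ff = nn.Linear(dim, dim), nn.Linear(dim, dim)
for layer in (zero_attn, zero_ff):
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)
identity_block = PreNormDecoder(zero_attn, zero_ff, dim)

# A stack of blocks followed by one final normalization before the output head
stack = nn.Sequential(*[PreNormDecoder(nn.Linear(dim, dim), nn.Linear(dim, dim), dim) for _ in range(4)])
final_norm = RMSNorm(dim)
with torch.no_grad():
    same = identity_block(x)
    out = final_norm(stack(x))
```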
The attention layer
Attention is the main ingredient that makes transformers so powerful. This component is split into different attention heads. Each attention head uses the input trajectory to produce part of a nudge towards the predicted next step.
Each attention head computes query, key, and value vectors for each input token. From the query and key vectors, the head computes the attention pattern: a map of the scale of the effect that each step of the input has on every other step (in a causal language model, each step can only look at itself and earlier steps). At a high level, it’s a word-to-word importance map. Multiplying this pattern by the head’s value vectors produces the result.
The results for each head are lined up side-by-side to reassemble the full dimensionality of the model’s semantic space. Finally, an output projection layer mixes the individual head outputs together into a prediction.
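The split into heads and the side-by-side reassembly are pure reshapes, as this small PyTorch sketch with hypothetical sizes shows: each head works on its own contiguous slice of the embedding dimensions, and undoing the reshape recovers the original tensor exactly. The `view`/`transpose` calls mirror the ones in the attention listing below.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, n_head = 2, 5, 12, 3  # hypothetical batch, sequence length, width, head count
hs = C // n_head               # per-head size

x = torch.randn(B, T, C)
heads = x.view(B, T, n_head, hs).transpose(1, 2)           # (B, nh, T, hs)
merged = heads.transpose(1, 2).contiguous().view(B, T, C)  # heads side by side again

# The output projection then mixes information across the heads' slices
o_proj = nn.Linear(C, C)
y = o_proj(merged)
```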
Implementation
In newer versions of PyTorch (2.0 and later), the scaled dot-product attention function is included natively. Unfortunately, this doesn’t provide as much insight into how attention works as the manual implementation does. Boy is it fast, though.
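To see that the two paths compute the same thing, here is a minimal comparison, assuming PyTorch 2.0 or later and random inputs of hypothetical shape `(batch, heads, sequence, head_size)`. `manual_causal_attention` is a hypothetical helper that follows the fallback branch of the listing below (without dropout) and also returns the attention pattern, whose rows each sum to 1 and whose entries above the diagonal (future steps) are zero.

```python
import math

import torch
import torch.nn.functional as F


def manual_causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask, written out step by step."""
    T = q.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))      # (B, nh, T, T) raw scores
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))         # True where attending is allowed
    att = att.masked_fill(~mask, float("-inf"))                   # block future steps
    att = F.softmax(att, dim=-1)                                  # each row becomes a distribution
    return att @ v, att


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
y_manual, pattern = manual_causal_attention(q, k, v)
y_native = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```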
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MHASelfAttention(nn.Module):
    """
    Llama-style Multi-Head Attention
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L577
    Modified to include manual attention fallback
    Note: this simplified listing does not apply rotary position embeddings;
    position_embeddings is accepted but unused here
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # Regularization (attn_dropout is only used by the manual fallback path)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # Causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
                                 .view(1, 1, config.max_seq_len, config.max_seq_len))

    def forward(self, x, position_embeddings=None):
        B, T, C = x.size()  # Batch size, sequence length, embedding dimensionality (n_embd)
        # Calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        # Causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            y = F.scaled_dot_product_attention(
                q, k, v, attn_mask=None,
                dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # Re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # Output projection
        y = self.resid_dropout(self.o_proj(y))
        return y
```
The MLP/feed-forward layer
This component calculates a new residual on top of the output of the preceding attention layer, further modifying the layer’s nudge. Unlike attention, this layer does not take the past into account; however, it can use information stored in the present position by the previous attention layers to modify the decoder’s overall movement.
Implementation
While MLPs are conceptually quite simple, how they operate as components of a more complex model is very emphatically not simple or straightforward.
```python
import torch.nn as nn


class MLP(nn.Module):
    """
    GPT-2 style MLP feed-forward block
    https://github.com/karpathy/nanoGPT/blob/master/model.py#L78
    Expands embedding dimension to intermediate size, then collapses
    back to pass to next block in transformer
    """
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)    # expand: (B, T, n_embd) -> (B, T, intermediate_size)
        x = self.gelu(x)    # element-wise non-linearity
        x = self.c_proj(x)  # collapse back: -> (B, T, n_embd)
        x = self.dropout(x)
        return x
```
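The claim that this layer ignores the past can be checked directly. In this sketch (hypothetical sizes, with `nn.Sequential` standing in for the `MLP` class above, minus dropout), rewriting the first three positions of a sequence leaves the outputs at the later positions unchanged, because the same weights are applied to each position independently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd, intermediate_size = 16, 64  # hypothetical sizes
mlp = nn.Sequential(
    nn.Linear(n_embd, intermediate_size),  # expand
    nn.GELU(),
    nn.Linear(intermediate_size, n_embd),  # collapse back
)

x = torch.randn(1, 6, n_embd)
x_changed = x.clone()
x_changed[:, :3] = torch.randn(1, 3, n_embd)  # rewrite the "past" positions only

with torch.no_grad():
    y, y_changed = mlp(x), mlp(x_changed)
```

Any mixing between positions therefore has to happen in the attention layers; the MLP only transforms what has already been gathered at each position.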
Output Layer
The detranslation layer is also called the output head. It takes the final point in the semantic space and maps it back to a mix of specific tokens that are most similar to that point. By using the same token-position mapping as the embedding layer, we turn both the input prompt and its generated continuation into a path through one consistent space, similar to how an accelerometer measures a signal in three dimensions.
Implementation
The output layer is a simple linear layer that transforms the final position output by the last decoder layer into a score (logit) for every token in the vocabulary, based on the dot-product similarity between each token’s position vector and that of the predicted prompt continuation. A softmax then turns these scores into a mix of token probabilities.
```python
# Excerpt from the language model's __init__ (not standalone): assumes self.transformer.wte
# holds the WordTokenEmbeddings module shown earlier
"""
Language Modelling "head"
maps the hidden state output by the transformer decoder stack
to a token from its vocabulary
"""
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
"""
"Weight Tying"
From Karpathy, which credits https://paperswithcode.com/method/weight-tying
By making input embeddings and output decodings map within the same space,
the generation of a new token is more like a "movement" through this latent space.
The word embeddings then must be clustered by frequency/semantic usage, with n_embd
different potential similarities or differences(/nuances).
"""
self.transformer.wte.wte.weight = self.lm_head.weight
```
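A small standalone sketch of what tying buys, with hypothetical sizes: after the assignment both modules hold the very same tensor, the output logits are exactly the dot products between the final hidden state and every token’s embedding, and a softmax turns them into the mix of token probabilities. Because `parameters()` yields shared tensors only once, the tied matrix is also counted only once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 50, 16  # hypothetical sizes
wte = nn.Embedding(vocab_size, n_embd)               # weight shape: (vocab_size, n_embd)
lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight shape: (vocab_size, n_embd)
wte.weight = lm_head.weight                          # tie: one matrix serves both directions

h = torch.randn(n_embd)        # final hidden state at one position
logits = lm_head(h)            # one score per vocabulary entry
scores = wte.weight @ h        # dot product with each token's embedding
probs = torch.softmax(logits, dim=-1)

model = nn.ModuleList([wte, lm_head])
n_params = sum(p.numel() for p in model.parameters())
```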
Language models in practice
How does all this math hash out in practice?
In the previous section, I abused the concept of language modelling being the task of predicting the next point in a trajectory; however, this analogy gets a bit hairy when we’re dealing with spaces with more than three dimensions. No 3D representation of 1k+ dimensions will be ideal, but it can still be illustrative.
In the animation below, the three principal components of a sentence in a model's 2000+ dimension semantic space are shown.
Animation courtesy of Thor Tronrud.
Black text and lines represent the input, while red represents the most probable predictions as determined by the model. As one can see, the model prefers to repetitively retread one specific path from the input. The problem is that the most probable path isn’t necessarily the best. In many cases, a wide range of different tokens have similar likelihoods of being the next step, and picking one that is potentially only marginally higher in probability locks the model into these repeating modes.
We can knock the model out of this behavior by randomly sampling the next token in proportion to the model’s confidence that each candidate is the proper continuation. Now we see many more of the transitions from the inputs represented.
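The two decoding rules can be sketched side by side, assuming a vector of next-token logits from the model. `greedy_next` always takes the single most probable token, while `sample_next` draws a token in proportion to its softmax probability; the optional `temperature` is a common extra knob (not discussed above) that sharpens or flattens the distribution. The logits below are made up, with two nearly tied candidates.

```python
import torch


def greedy_next(logits):
    """Deterministic: always the highest-scoring token."""
    return int(torch.argmax(logits))


def sample_next(logits, temperature=1.0, generator=None):
    """Stochastic: draw a token with probability given by the softmax of the logits."""
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1, generator=generator))


logits = torch.tensor([2.0, 1.9, 0.5])  # made-up scores: tokens 0 and 1 are nearly tied
g = torch.Generator().manual_seed(0)

print(greedy_next(logits))  # 0, every time
draws = [sample_next(logits, generator=g) for _ in range(2000)]
# Empirical frequencies approach softmax(logits), about [0.47, 0.43, 0.10]
print([draws.count(t) / len(draws) for t in range(3)])
```

Greedy decoding would pick token 0 every time even though token 1 is almost as likely, which is the lock-in behavior described above; sampling visits both.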
Animation courtesy of Thor Tronrud.
Summary
Regulatory bodies see the value of ML/AI devices as methods to improve the speed and standard of care with experts driving the decision making. The application of language models to medical devices must operate within the regulatory frameworks laid out by bodies like the FDA.
This article pulls back the ML/AI veil, revealing the layout and function of these models as understandable processes, which is the best place to start when developing robust and long-term applications.
About the author:
Image courtesy of Thor Tronrud.
Thor Tronrud is a research and data analysis-focused software engineer at StarFish Medical who specializes in the development and application of machine learning tools. Previously an astrophysicist working with magnetohydrodynamic simulations, Tronrud joined StarFish in 2021 and has applied machine learning techniques to problems including image segmentation, signal analysis, and language processing.