Mai Giménez

Mai Giménez, PhD, is a staff research engineer at Google DeepMind, working on large language models and multimodal models. She is passionate about building the most useful technology for everyone, and her main research interests are language and the real-world sociotechnical impacts of these models.

Mai is a former board member of the Spanish Python Association, has helped organise several PyConES conferences, and is a proud member of PyLadies.


Sessions

07-14
09:30
90min
Let it rip a diffusion tutorial
Mai Giménez

Implementing high-performance deep learning models often feels like a struggle between readable Python code and the low-level optimizations required for modern GPUs and TPUs. JAX bridges this gap by expressing neural networks as pure functions that it can transform (differentiate, vectorize) and compile. In this session, we will move beyond the abstractions of high-level frameworks to build a Denoising Diffusion Probabilistic Model (DDPM) from the ground up.
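
As a minimal sketch of the "pure function" idea (not the tutorial's own code), the DDPM forward noising process has a closed form, q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I), so it can be written as a side-effect-free function that jax.jit compiles. It assumes a linear beta schedule from 1e-4 to 0.02 over T = 1000 steps (the DDPM paper's defaults); the helper name q_sample is hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumed linear beta schedule (DDPM paper defaults: T = 1000, beta from 1e-4 to 0.02).
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)  # alpha_bar_t = product of alpha_s for s <= t


@jax.jit
def q_sample(x0, t, key):
    """Sample x_t ~ q(x_t | x_0) in closed form (hypothetical helper name).

    Pure function: the randomness comes in through an explicit PRNG key,
    so the same (x0, t, key) always yields the same (x_t, noise).
    """
    noise = jax.random.normal(key, x0.shape)
    a_bar = alpha_bars[t]
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    return x_t, noise


key = jax.random.PRNGKey(0)
x0 = jnp.ones((4, 4))  # stand-in for a tiny "image"
x_t, eps = q_sample(x0, 500, key)
```

Because nothing is hidden in global random state, calling q_sample again with the same key reproduces the sample exactly; fresh noise requires splitting the key.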

We will explore why JAX’s functional programming paradigm is well suited to the stochastic nature of diffusion, where every source of randomness must be explicit and reproducible. You will learn how to:

  • Master JIT (Just-In-Time) Compilation: See how @jax.jit traces Python functions and compiles them with XLA into optimized kernels, often yielding large speedups.
  • Leverage Vectorized Mapping: Use @jax.vmap to turn a per-example function into a batched one, without the overhead of manual Python loops.
  • Dissect the Diffusion Pipeline: Step through the forward noising process (and its continuous-time SDE view) and the reverse denoising process, learned with a denoising score-matching objective.
  • Manage State and PRNGs: Navigate JAX’s explicit handling of random number generation through PRNG keys, and its stateless transformations.
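
The four points above meet in a single reverse (sampling) step. The sketch below is an illustration under stated assumptions, not the tutorial's implementation: it uses the same assumed linear schedule, the ancestral DDPM update with sigma_t^2 = beta_t, and a hypothetical toy_eps_model placeholder in place of a trained noise-prediction network. The per-example function p_sample_one (hypothetical name) is batched with jax.vmap, compiled with jax.jit, and every step splits fresh PRNG keys explicitly.

```python
import jax
import jax.numpy as jnp

# Assumed linear beta schedule (DDPM paper defaults).
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)


def toy_eps_model(params, x, t):
    # Hypothetical stand-in for a trained eps_theta(x_t, t); a real model
    # (e.g. a U-Net conditioned on t) would go here. This one just scales x.
    return params["w"] * x


def p_sample_one(params, x_t, t, key):
    """One ancestral DDPM reverse step for a single example: x_t -> x_{t-1}."""
    eps = toy_eps_model(params, x_t, t)
    mean = (x_t - betas[t] / jnp.sqrt(1.0 - alpha_bars[t]) * eps) / jnp.sqrt(alphas[t])
    z = jax.random.normal(key, x_t.shape)
    # sigma_t = sqrt(beta_t), except on the final step (t == 0) where no noise is added.
    sigma = jnp.where(t > 0, jnp.sqrt(betas[t]), 0.0)
    return mean + sigma * z


# vmap over the batch axis of x_t and of the keys; params and t are shared (None).
p_sample_batch = jax.jit(jax.vmap(p_sample_one, in_axes=(None, 0, None, 0)))


def sample(params, key, shape, batch_size):
    """Run the full reverse chain from pure noise, splitting keys explicitly."""
    key, init_key = jax.random.split(key)
    x = jax.random.normal(init_key, (batch_size, *shape))
    for t in range(T - 1, -1, -1):
        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, batch_size)  # one key per example
        x = p_sample_batch(params, x, t, step_keys)
    return x
```

The Python loop keeps the sketch readable; wrapping the chain in jax.lax.scan or jax.lax.fori_loop would let XLA compile the whole loop instead of dispatching one compiled step at a time.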

This tutorial is designed for Python developers and ML engineers who want to understand the "how" and "why" behind state-of-the-art text-to-image models. You will leave with a solid understanding of the diffusion objective and the practical skills to build and deploy high-performance model architectures with the JAX ecosystem.
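
The diffusion objective itself is short. As a hedged sketch, assuming the simplified epsilon-prediction loss from the DDPM paper, the same linear schedule, and a hypothetical toy_eps_model that ignores t, a training step samples a timestep and noise per example, builds x_t in closed form, and regresses the predicted noise onto the true noise; jax.value_and_grad supplies the gradient. The names ddpm_loss and train_step are hypothetical.

```python
import jax
import jax.numpy as jnp

# Assumed linear beta schedule (DDPM paper defaults).
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alpha_bars = jnp.cumprod(1.0 - betas)
LR = 1e-2  # assumed learning rate for plain SGD


def toy_eps_model(params, x_t, t):
    # Hypothetical placeholder for eps_theta(x_t, t); a real model would condition on t.
    return params["w"] * x_t + params["b"]


def ddpm_loss(params, x0_batch, key):
    """Simplified DDPM objective: MSE between the true and the predicted noise."""
    t_key, noise_key = jax.random.split(key)
    batch = x0_batch.shape[0]
    t = jax.random.randint(t_key, (batch,), 0, T)  # t ~ Uniform{0, ..., T-1}
    noise = jax.random.normal(noise_key, x0_batch.shape)
    # Reshape per-example alpha_bar so it broadcasts over the remaining axes.
    a_bar = alpha_bars[t].reshape((batch,) + (1,) * (x0_batch.ndim - 1))
    x_t = jnp.sqrt(a_bar) * x0_batch + jnp.sqrt(1.0 - a_bar) * noise
    pred = toy_eps_model(params, x_t, t)
    return jnp.mean((noise - pred) ** 2)


@jax.jit
def train_step(params, x0_batch, key):
    loss, grads = jax.value_and_grad(ddpm_loss)(params, x0_batch, key)
    # Plain SGD update applied leaf-by-leaf over the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
key = jax.random.PRNGKey(0)
x0 = jax.random.normal(jax.random.PRNGKey(1), (8, 4, 4))  # stand-in "images"
for _ in range(200):
    key, step_key = jax.random.split(key)
    params, loss = train_step(params, x0, step_key)
```

Plain SGD keeps the example dependency-free; a real training loop would more likely use an optimizer library such as Optax.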

Machine Learning, NLP and CV
Conference Hall Complex B (S4B)
07-14
11:15
90min
Let it rip a diffusion tutorial
Mai Giménez

Machine Learning, NLP and CV
Conference Hall Complex B (S4B)