From 1071a41cb4f75c34d8cd4752db93dffb346bca3d Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Mon, 25 Sep 2023 07:50:45 -0700
Subject: [PATCH] [flax] Remove redundant `deterministic` parameter from
 `dot_product_attention`.

PiperOrigin-RevId: 568217601
---
 init2winit/model_lib/conformer.py | 18 +++++++++++-------
 init2winit/model_lib/vit.py       | 14 ++++++++++----
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py
index ee451e69..ff877a8e 100644
--- a/init2winit/model_lib/conformer.py
+++ b/init2winit/model_lib/conformer.py
@@ -402,7 +402,6 @@ def dot_product_attention(query,
                           broadcast_dropout=True,
                           dropout_rng=None,
                           dropout_rate=0.,
-                          deterministic=False,
                           dtype=jnp.float32,
                           precision=None,
                           temperature=1.0):
@@ -434,7 +433,6 @@ def dot_product_attention(query,
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
     dropout_rng: JAX PRNGKey: to be used for dropout
     dropout_rate: dropout rate
-    deterministic: bool, deterministic or not (to apply dropout)
     dtype: the dtype of the computation (default: float32)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
@@ -452,11 +450,17 @@ def dot_product_attention(query,
 
   # compute attention weights
   query = QueryScaler(dim=query.shape[-1])(query)
-  attn_weights = nn.dot_product_attention_weights(query, key, bias, mask,
-                                                  broadcast_dropout,
-                                                  dropout_rng, dropout_rate,
-                                                  deterministic, dtype,
-                                                  precision)
+  attn_weights = nn.dot_product_attention_weights(
+      query,
+      key,
+      bias,
+      mask,
+      broadcast_dropout,
+      dropout_rng,
+      dropout_rate,
+      dtype,
+      precision,
+  )
 
   # return weighted sum over values for each query position
   return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
diff --git a/init2winit/model_lib/vit.py b/init2winit/model_lib/vit.py
index 9ff72fc9..8ffaf531 100644
--- a/init2winit/model_lib/vit.py
+++ b/init2winit/model_lib/vit.py
@@ -108,7 +108,6 @@ def dot_product_attention(query,
                           broadcast_dropout=True,
                           dropout_rng=None,
                           dropout_rate=0.,
-                          deterministic=False,
                           dtype=jnp.float32,
                           precision=None,
                           temperature=1.0):
@@ -139,7 +138,6 @@ def dot_product_attention(query,
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
     dropout_rng: JAX PRNGKey: to be used for dropout
     dropout_rate: dropout rate
-    deterministic: bool, deterministic or not (to apply dropout)
     dtype: the dtype of the computation (default: infer from inputs)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
@@ -157,8 +155,16 @@ def dot_product_attention(query,
 
   # compute attention weights
   attn_weights = nn.dot_product_attention_weights(
-      query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate,
-      deterministic, dtype, precision)
+      query,
+      key,
+      bias,
+      mask,
+      broadcast_dropout,
+      dropout_rng,
+      dropout_rate,
+      dtype,
+      precision,
+  )
 
   # return weighted sum over values for each query position
   return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
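
Note (not part of the patch): a minimal sketch of the keyword-argument form of
the `flax.linen.dot_product_attention_weights` call, which does not depend on
the positional slot that `deterministic` used to occupy. The input shapes and
the dropout-free settings below are illustrative assumptions, not taken from
init2winit.

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Toy inputs with shape [batch, length, num_heads, head_dim].
    q_rng, k_rng = jax.random.split(jax.random.PRNGKey(0))
    query = jax.random.normal(q_rng, (2, 16, 4, 32))
    key = jax.random.normal(k_rng, (2, 16, 4, 32))

    # Passing everything after `query` and `key` by keyword keeps the call
    # valid regardless of how the remaining parameters are ordered.
    attn_weights = nn.dot_product_attention_weights(
        query,
        key,
        bias=None,
        mask=None,
        broadcast_dropout=True,
        dropout_rng=None,  # no RNG and a zero rate -> no dropout is applied
        dropout_rate=0.,
        dtype=jnp.float32,
        precision=None,
    )
    # attn_weights has shape [batch, num_heads, q_length, kv_length],
    # i.e. (2, 4, 16, 16) for the toy inputs above.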