diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py
index ee451e69..ff877a8e 100644
--- a/init2winit/model_lib/conformer.py
+++ b/init2winit/model_lib/conformer.py
@@ -402,7 +402,6 @@ def dot_product_attention(query,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
-                           deterministic=False,
                            dtype=jnp.float32,
                            precision=None,
                            temperature=1.0):
@@ -434,7 +433,6 @@ def dot_product_attention(query,
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
     dropout_rng: JAX PRNGKey: to be used for dropout
     dropout_rate: dropout rate
-    deterministic: bool, deterministic or not (to apply dropout)
     dtype: the dtype of the computation (default: float32)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
@@ -452,11 +450,17 @@ def dot_product_attention(query,
 
   # compute attention weights
   query = QueryScaler(dim=query.shape[-1])(query)
-  attn_weights = nn.dot_product_attention_weights(query, key, bias, mask,
-                                                  broadcast_dropout,
-                                                  dropout_rng, dropout_rate,
-                                                  deterministic, dtype,
-                                                  precision)
+  attn_weights = nn.dot_product_attention_weights(
+      query,
+      key,
+      bias,
+      mask,
+      broadcast_dropout,
+      dropout_rng,
+      dropout_rate,
+      dtype,
+      precision,
+  )
 
   # return weighted sum over values for each query position
   return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
diff --git a/init2winit/model_lib/vit.py b/init2winit/model_lib/vit.py
index 9ff72fc9..8ffaf531 100644
--- a/init2winit/model_lib/vit.py
+++ b/init2winit/model_lib/vit.py
@@ -108,7 +108,6 @@ def dot_product_attention(query,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
-                           deterministic=False,
                            dtype=jnp.float32,
                            precision=None,
                            temperature=1.0):
@@ -139,7 +138,6 @@ def dot_product_attention(query,
     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
     dropout_rng: JAX PRNGKey: to be used for dropout
     dropout_rate: dropout rate
-    deterministic: bool, deterministic or not (to apply dropout)
     dtype: the dtype of the computation (default: infer from inputs)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
@@ -157,8 +155,16 @@ def dot_product_attention(query,
 
   # compute attention weights
   attn_weights = nn.dot_product_attention_weights(
-      query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate,
-      deterministic, dtype, precision)
+      query,
+      key,
+      bias,
+      mask,
+      broadcast_dropout,
+      dropout_rng,
+      dropout_rate,
+      dtype,
+      precision,
+  )
 
   # return weighted sum over values for each query position
   return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
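
A minimal sketch (not part of the patch) of the call pattern this diff moves to, assuming the convention implied by the change: attention dropout in `nn.dot_product_attention_weights` is driven only by `dropout_rate` and `dropout_rng`, with no separate `deterministic` flag. The tensor shapes and the 0.1 dropout rate below are illustrative assumptions, not values from the repo.

import jax
import jax.numpy as jnp
from flax import linen as nn

batch, seq_len, num_heads, head_dim = 2, 16, 4, 8
rng = jax.random.PRNGKey(0)
q_rng, k_rng, dropout_rng = jax.random.split(rng, 3)

# Query/key follow Flax's [batch, length, num_heads, head_dim] layout.
query = jax.random.normal(q_rng, (batch, seq_len, num_heads, head_dim))
key = jax.random.normal(k_rng, (batch, seq_len, num_heads, head_dim))

# Training-style call: a dropout RNG plus a non-zero rate enables attention dropout.
train_weights = nn.dot_product_attention_weights(
    query,
    key,
    dropout_rng=dropout_rng,
    dropout_rate=0.1,
    dtype=jnp.float32,
)

# Eval-style call: a zero rate (and no RNG) keeps the weights deterministic,
# which is what callers previously expressed via `deterministic=True`.
eval_weights = nn.dot_product_attention_weights(
    query,
    key,
    dropout_rate=0.0,
    dtype=jnp.float32,
)

print(train_weights.shape, eval_weights.shape)  # (2, 4, 16, 16) for both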