Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flax] Remove redundant deterministic parameter from dot_product_attention. #588

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions init2winit/model_lib/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,6 @@ def dot_product_attention(query,
broadcast_dropout=True,
dropout_rng=None,
dropout_rate=0.,
deterministic=False,
dtype=jnp.float32,
precision=None,
temperature=1.0):
Expand Down Expand Up @@ -434,7 +433,6 @@ def dot_product_attention(query,
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: float32)
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
Expand All @@ -452,11 +450,17 @@ def dot_product_attention(query,

# compute attention weights
query = QueryScaler(dim=query.shape[-1])(query)
attn_weights = nn.dot_product_attention_weights(query, key, bias, mask,
broadcast_dropout,
dropout_rng, dropout_rate,
deterministic, dtype,
precision)
attn_weights = nn.dot_product_attention_weights(
query,
key,
bias,
mask,
broadcast_dropout,
dropout_rng,
dropout_rate,
dtype,
precision,
)

# return weighted sum over values for each query position
return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
Expand Down
14 changes: 10 additions & 4 deletions init2winit/model_lib/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def dot_product_attention(query,
broadcast_dropout=True,
dropout_rng=None,
dropout_rate=0.,
deterministic=False,
dtype=jnp.float32,
precision=None,
temperature=1.0):
Expand Down Expand Up @@ -139,7 +138,6 @@ def dot_product_attention(query,
broadcast_dropout: bool: use a broadcasted dropout along batch dims.
dropout_rng: JAX PRNGKey: to be used for dropout
dropout_rate: dropout rate
deterministic: bool, deterministic or not (to apply dropout)
dtype: the dtype of the computation (default: infer from inputs)
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
Expand All @@ -157,8 +155,16 @@ def dot_product_attention(query,

# compute attention weights
attn_weights = nn.dot_product_attention_weights(
query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate,
deterministic, dtype, precision)
query,
key,
bias,
mask,
broadcast_dropout,
dropout_rng,
dropout_rate,
dtype,
precision,
)

# return weighted sum over values for each query position
return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value,
Expand Down