Test Dropout and BatchNorm layers export
juliabeliaeva committed Apr 24, 2022
1 parent 0491460 commit 08ebcb8
Showing 1 changed file with 43 additions and 0 deletions.
@@ -0,0 +1,43 @@
/*
 * Copyright 2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
 */

package org.jetbrains.kotlinx.dl.api.inference.keras

import org.jetbrains.kotlinx.dl.api.core.Sequential
import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal
import org.jetbrains.kotlinx.dl.api.core.initializer.HeUniform
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
import org.jetbrains.kotlinx.dl.api.core.layer.normalization.BatchNorm
import org.jetbrains.kotlinx.dl.api.core.layer.regularization.Dropout
import org.jetbrains.kotlinx.dl.api.core.regularizer.L2
import org.junit.jupiter.api.Test

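/**
 * Verifies that [Dropout] and [BatchNorm] layer configurations survive a Keras config
 * export/import round trip via the shared [LayerImportExportTest.run] helper.
 */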
class DropoutBatchNormImportExportTest {
    @Test
    fun dropout() {
        LayerImportExportTest.run(
            Sequential.of(
                Input(10),
                Dropout(name = "test_dropout", rate = 0.2f)
            )
        )
    }

    @Test
    fun batchNorm() {
        LayerImportExportTest.run(
            Sequential.of(
                Input(10),
                BatchNorm(
                    name = "test_batch_norm", axis = listOf(1),
                    momentum = 0.9, center = false, epsilon = 0.002, scale = false,
                    gammaInitializer = HeUniform(), betaInitializer = HeNormal(),
                    betaRegularizer = L2(), gammaRegularizer = L2(),
                    movingMeanInitializer = HeNormal(), movingVarianceInitializer = HeUniform()
                )
            )
        )
    }
}
