Fix LSA issues and harmonize with the text (sryza#104)
sryza authored Apr 2, 2017
1 parent 8c1df8c commit e8754e0
Showing 3 changed files with 9 additions and 6 deletions.
4 changes: 4 additions & 0 deletions ch06-lsa/pom.xml
@@ -38,6 +38,10 @@
       <groupId>edu.umd</groupId>
       <artifactId>cloud9</artifactId>
     </dependency>
+    <dependency>
+      <groupId>info.bliki.wiki</groupId>
+      <artifactId>bliki-core</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client</artifactId>
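The new bliki-core artifact is a MediaWiki-markup parser; it is presumably needed on the classpath by the chapter's Wikipedia preprocessing, and since no version is declared here it is assumed to come from the parent POM's dependency management. A minimal sketch of what the library does, using bliki's own WikiModel API rather than anything defined in this repo:

    import info.bliki.wiki.filter.PlainTextConverter
    import info.bliki.wiki.model.WikiModel

    // Render a fragment of MediaWiki markup to plain text, roughly the kind of
    // cleanup applied to Wikipedia dump pages before tokenization.
    object BlikiSketch {
      def main(args: Array[String]): Unit = {
        val model = new WikiModel("${image}", "${title}")
        val markup = "'''Latent semantic analysis''' relates [[document]]s and [[term]]s."
        println(model.render(new PlainTextConverter(), markup))
      }
    }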
@@ -103,11 +103,11 @@ class AssembleDocumentTermMatrix(private val spark: SparkSession) extends Serial
    * Returns a document-term matrix where each element is the TF-IDF of the row's document and
    * the column's term.
    *
-   * @param docs a DF with two columns: title and text
+   * @param docTexts a DF with two columns: title and text
    */
-  def documentTermMatrix(docs: Dataset[(String, String)], stopWordsFile: String, numTerms: Int)
+  def documentTermMatrix(docTexts: Dataset[(String, String)], stopWordsFile: String, numTerms: Int)
     : (DataFrame, Array[String], Map[Long, String], Array[Double]) = {
-    val terms = contentsToTerms(docs, stopWordsFile)
+    val terms = contentsToTerms(docTexts, stopWordsFile)

     val termsDF = terms.toDF("title", "terms")
     val filtered = termsDF.where(size($"terms") > 1)
@@ -121,8 +121,7 @@ class AssembleDocumentTermMatrix(private val spark: SparkSession) extends Serial

     docTermFreqs.cache()

-    val docIdsDF = docTermFreqs.withColumn("id", monotonically_increasing_id)
-    val docIds = docIdsDF.select("id", "title").as[(Long, String)].collect().toMap
+    val docIds = docTermFreqs.rdd.map(_.getString(0)).zipWithUniqueId().map(_.swap).collect().toMap

     val idf = new IDF().setInputCol("termFreqs").setOutputCol("tfidfVec")
     val idfModel = idf.fit(docTermFreqs)
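The docIds change swaps a monotonically_increasing_id column for RDD-level zipWithUniqueId, presumably so the id-to-title map uses the same numbering as other points in the pipeline that zip the same rows with unique ids. A minimal sketch of the new pattern, assuming only that the first column of docTermFreqs holds the document title:

    import org.apache.spark.sql.DataFrame

    // Build an id -> title lookup: zipWithUniqueId assigns every RDD element a
    // distinct Long without a shuffle, and swapping each pair keys the map by id.
    def docIdMap(docTermFreqs: DataFrame): Map[Long, String] = {
      docTermFreqs.rdd
        .map(_.getString(0))   // title column
        .zipWithUniqueId()     // (title, id)
        .map(_.swap)           // (id, title)
        .collect()
        .toMap
    }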
@@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer
 object RunLSA {
   def main(args: Array[String]): Unit = {
     val k = if (args.length > 0) args(0).toInt else 100
-    val numTerms = if (args.length > 1) args(1).toInt else 50000
+    val numTerms = if (args.length > 1) args(1).toInt else 20000

     val spark = SparkSession.builder().config("spark.serializer", classOf[KryoSerializer].getName).getOrCreate()
     val assembleMatrix = new AssembleDocumentTermMatrix(spark)
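With the default vocabulary size lowered to 20,000 terms, the surrounding driver wiring is unchanged; below is a hypothetical fragment showing how the parsed defaults would feed the renamed documentTermMatrix signature (the stop-words path and the run helper are placeholders, not code from this repo):

    import org.apache.spark.sql.Dataset

    // Hypothetical driver fragment: k singular values, numTerms vocabulary entries,
    // and a Dataset of (title, text) pairs passed to documentTermMatrix.
    def run(args: Array[String], assembleMatrix: AssembleDocumentTermMatrix,
            docTexts: Dataset[(String, String)]): Unit = {
      val k = if (args.length > 0) args(0).toInt else 100
      val numTerms = if (args.length > 1) args(1).toInt else 20000  // new default
      val (docTermMatrix, termIds, docIds, termIdfs) =
        assembleMatrix.documentTermMatrix(docTexts, "stopwords.txt", numTerms)
      println(s"doc-term matrix covers ${docIds.size} docs and ${termIds.length} terms (k = $k)")
    }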
