diff --git a/ch06-lsa/pom.xml b/ch06-lsa/pom.xml
index 5b43fea9..9add4b4b 100644
--- a/ch06-lsa/pom.xml
+++ b/ch06-lsa/pom.xml
@@ -38,6 +38,10 @@
       <groupId>edu.umd</groupId>
       <artifactId>cloud9</artifactId>
     </dependency>
+    <dependency>
+      <groupId>info.bliki.wiki</groupId>
+      <artifactId>bliki-core</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client</artifactId>
diff --git a/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/AssembleDocumentTermMatrix.scala b/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/AssembleDocumentTermMatrix.scala
index 169e92bb..6bd6abe3 100644
--- a/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/AssembleDocumentTermMatrix.scala
+++ b/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/AssembleDocumentTermMatrix.scala
@@ -103,11 +103,11 @@ class AssembleDocumentTermMatrix(private val spark: SparkSession) extends Serial
    * Returns a document-term matrix where each element is the TF-IDF of the row's document and
    * the column's term.
    *
-   * @param docs a DF with two columns: title and text
+   * @param docTexts a DF with two columns: title and text
    */
-  def documentTermMatrix(docs: Dataset[(String, String)], stopWordsFile: String, numTerms: Int)
+  def documentTermMatrix(docTexts: Dataset[(String, String)], stopWordsFile: String, numTerms: Int)
     : (DataFrame, Array[String], Map[Long, String], Array[Double]) = {
-    val terms = contentsToTerms(docs, stopWordsFile)
+    val terms = contentsToTerms(docTexts, stopWordsFile)
     val termsDF = terms.toDF("title", "terms")
 
     val filtered = termsDF.where(size($"terms") > 1)
@@ -121,8 +121,7 @@ class AssembleDocumentTermMatrix(private val spark: SparkSession) extends Serial
 
     docTermFreqs.cache()
 
-    val docIdsDF = docTermFreqs.withColumn("id", monotonically_increasing_id)
-    val docIds = docIdsDF.select("id", "title").as[(Long, String)].collect().toMap
+    val docIds = docTermFreqs.rdd.map(_.getString(0)).zipWithUniqueId().map(_.swap).collect().toMap
 
     val idf = new IDF().setInputCol("termFreqs").setOutputCol("tfidfVec")
     val idfModel = idf.fit(docTermFreqs)
diff --git a/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/RunLSA.scala b/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/RunLSA.scala
index b00a2425..5509991c 100644
--- a/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/RunLSA.scala
+++ b/ch06-lsa/src/main/scala/com/cloudera/datascience/lsa/RunLSA.scala
@@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer
 object RunLSA {
   def main(args: Array[String]): Unit = {
     val k = if (args.length > 0) args(0).toInt else 100
-    val numTerms = if (args.length > 1) args(1).toInt else 50000
+    val numTerms = if (args.length > 1) args(1).toInt else 20000
 
     val spark = SparkSession.builder().config("spark.serializer", classOf[KryoSerializer].getName).getOrCreate()
     val assembleMatrix = new AssembleDocumentTermMatrix(spark)