diff --git a/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java b/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java
index cde07ec01..555859415 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/mr/EsInputFormat.java
@@ -412,7 +412,7 @@ public EsInputRecordReader createRecordReader(InputSplit split, TaskAttemp
     public org.apache.hadoop.mapred.InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
         Settings settings = HadoopSettingsManager.loadFrom(job);
 
-        Collection<PartitionDefinition> partitions = RestService.findPartitions(settings, log);
+        Collection<PartitionDefinition> partitions = RestService.findPartitions(settings, log, null);
         EsInputSplit[] splits = new EsInputSplit[partitions.size()];
 
         int index = 0;
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java
index 1e05bfad4..d9d7a13d4 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestClient.java
@@ -311,14 +311,22 @@ public List>> targetShards(String index, String routing
     }
 
     public MappingSet getMappings(Resource indexResource) {
+        return getMappings(indexResource, Collections.emptyList());
+    }
+
+    public MappingSet getMappings(Resource indexResource, Collection<String> includeFields) {
         if (indexResource.isTyped()) {
-            return getMappings(indexResource.index() + "/_mapping/" + indexResource.type(), true);
+            return getMappings(indexResource.index() + "/_mapping/" + indexResource.type(), true, includeFields);
         } else {
-            return getMappings(indexResource.index() + "/_mapping" + (indexReadMissingAsEmpty ? "?ignore_unavailable=true" : ""), false);
+            return getMappings(indexResource.index() + "/_mapping" + (indexReadMissingAsEmpty ? "?ignore_unavailable=true" : ""), false, includeFields);
         }
     }
 
     public MappingSet getMappings(String query, boolean includeTypeName) {
+        return getMappings(query, includeTypeName, Collections.emptyList());
+    }
+
+    public MappingSet getMappings(String query, boolean includeTypeName, Collection<String> includeFields) {
         // If the version is not at least 7, then the property isn't guaranteed to exist. If it is, then defer to the flag.
         boolean requestTypeNameInResponse = clusterInfo.getMajorVersion().onOrAfter(EsMajorVersion.V_7_X) && includeTypeName;
         // Response will always have the type name in it if node version is before 7, and if it is not, defer to the flag.
@@ -328,7 +336,7 @@ public MappingSet getMappings(String query, boolean includeTypeName) {
         }
         Map<String, Object> result = get(query, null);
         if (result != null && !result.isEmpty()) {
-            return FieldParser.parseMappings(result, typeNameInResponse);
+            return FieldParser.parseMappings(result, typeNameInResponse, includeFields);
         }
         return null;
     }
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java
index 23609a4cf..47248aaa3 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java
@@ -54,11 +54,7 @@ import java.io.Closeable;
 import java.io.IOException;
 import java.io.InputStream;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 import java.util.Map.Entry;
 
 import static org.elasticsearch.hadoop.rest.Request.Method.POST;
@@ -300,6 +296,10 @@ public MappingSet getMappings() {
         return client.getMappings(resources.getResourceRead());
     }
 
+    public MappingSet getMappings(Collection<String> includeFields) {
+        return client.getMappings(resources.getResourceRead(), includeFields);
+    }
+
     public Map sampleGeoFields(Mapping mapping) {
         Map fields = MappingUtils.geoFields(mapping);
         Map geoMapping = client.sampleForFields(resources.getResourceRead(), fields.keySet());
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java
index d0b5ad58b..2fb5f8078 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/rest/RestService.java
@@ -212,7 +212,7 @@ public void remove() {
     }
 
     @SuppressWarnings("unchecked")
-    public static List<PartitionDefinition> findPartitions(Settings settings, Log log) {
+    public static List<PartitionDefinition> findPartitions(Settings settings, Log log, Mapping resolvedMapping) {
         Version.logVersion();
 
         InitializationUtils.validateSettings(settings);
@@ -244,16 +244,18 @@ public static List findPartitions(Settings settings, Log lo
         log.info(String.format("Reading from [%s]", settings.getResourceRead()));
 
-        MappingSet mapping = null;
+        Mapping mapping = resolvedMapping;
         if (!shards.isEmpty()) {
-            mapping = client.getMappings();
+            if (mapping == null) {
+                mapping = client.getMappings().getResolvedView();
+            }
             if (log.isDebugEnabled()) {
-                log.debug(String.format("Discovered resolved mapping {%s} for [%s]", mapping.getResolvedView(), settings.getResourceRead()));
+                log.debug(String.format("Discovered resolved mapping {%s} for [%s]", mapping, settings.getResourceRead()));
             }
             // validate if possible
             FieldPresenceValidation validation = settings.getReadFieldExistanceValidation();
             if (validation.isRequired()) {
-                MappingUtils.validateMapping(SettingsUtils.determineSourceFields(settings), mapping.getResolvedView(), validation, log);
+                MappingUtils.validateMapping(SettingsUtils.determineSourceFields(settings), mapping, validation, log);
             }
         }
 
         final Map nodesMap = new HashMap();
@@ -278,9 +280,8 @@ public static List findPartitions(Settings settings, Log lo
     /**
      * Create one {@link PartitionDefinition} per shard for each requested index.
      */
-    static List<PartitionDefinition> findShardPartitions(Settings settings, MappingSet mappingSet, Map nodes,
+    static List<PartitionDefinition> findShardPartitions(Settings settings, Mapping resolvedMapping, Map nodes,
             List<List<Map<String, Object>>> shards, Log log) {
-        Mapping resolvedMapping = mappingSet == null ? null : mappingSet.getResolvedView();
         List<PartitionDefinition> partitions = new ArrayList<PartitionDefinition>(shards.size());
         PartitionDefinition.PartitionDefinitionBuilder partitionBuilder = PartitionDefinition.builder(settings, resolvedMapping);
         for (List<Map<String, Object>> group : shards) {
@@ -316,13 +317,12 @@ static List findShardPartitions(Settings settings, MappingS
     /**
      * Partitions the query based on the max number of documents allowed per partition {@link Settings#getMaxDocsPerPartition()}.
      */
-    static List<PartitionDefinition> findSlicePartitions(RestClient client, Settings settings, MappingSet mappingSet,
+    static List<PartitionDefinition> findSlicePartitions(RestClient client, Settings settings, Mapping resolvedMapping,
             Map nodes, List<List<Map<String, Object>>> shards, Log log) {
         QueryBuilder query = QueryUtils.parseQueryAndFilters(settings);
         Integer maxDocsPerPartition = settings.getMaxDocsPerPartition();
         Assert.notNull(maxDocsPerPartition, "Attempting to find slice partitions but maximum documents per partition is not set.");
         Resource readResource = new Resource(settings, true);
-        Mapping resolvedMapping = mappingSet == null ? null : mappingSet.getResolvedView();
 
         PartitionDefinition.PartitionDefinitionBuilder partitionBuilder = PartitionDefinition.builder(settings, resolvedMapping);
         List<PartitionDefinition> partitions = new ArrayList<PartitionDefinition>(shards.size());
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java
index 8a1cd7763..99f0114a0 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/FieldParser.java
@@ -19,10 +19,7 @@ package org.elasticsearch.hadoop.serialization.dto.mapping;
 
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException;
 import org.elasticsearch.hadoop.serialization.FieldType;
@@ -52,13 +49,25 @@ public static MappingSet parseTypelessMappings(Map content) {
      * @return MappingSet for that response.
      */
     public static MappingSet parseMappings(Map<String, Object> content, boolean includeTypeName) {
+        return parseMappings(content, includeTypeName, Collections.emptyList());
+    }
+
+    /**
+     * Convert the deserialized mapping request body into an object
+     * @param content entire mapping request body for all indices and types
+     * @param includeTypeName true if the given content to be parsed includes type names within the structure,
+     *                        or false if it is in the typeless format
+     * @param includeFields list of fields that should have their mappings checked
+     * @return MappingSet for that response.
+     */
+    public static MappingSet parseMappings(Map<String, Object> content, boolean includeTypeName, Collection<String> includeFields) {
         Iterator<Map.Entry<String, Object>> indices = content.entrySet().iterator();
         List<Mapping> indexMappings = new ArrayList<Mapping>();
         while(indices.hasNext()) {
             // These mappings are ordered by index, then optionally type.
             parseIndexMappings(indices.next(), indexMappings, includeTypeName);
         }
-        return new MappingSet(indexMappings);
+        return new MappingSet(indexMappings, includeFields);
     }
 
     private static void parseIndexMappings(Map.Entry<String, Object> indexToMappings, List<Mapping> collector, boolean includeTypeName) {
diff --git a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java
index 8e1be4f09..438165948 100644
--- a/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java
+++ b/mr/src/main/java/org/elasticsearch/hadoop/serialization/dto/mapping/MappingSet.java
@@ -20,12 +20,7 @@ package org.elasticsearch.hadoop.serialization.dto.mapping;
 
 import java.io.Serializable;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.LinkedHashSet;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException;
 import org.elasticsearch.hadoop.serialization.FieldType;
@@ -46,7 +41,7 @@ public class MappingSet implements Serializable {
     private final Map<String, Map<String, Mapping>> indexTypeMap = new HashMap<String, Map<String, Mapping>>();
     private final Mapping resolvedSchema;
 
-    public MappingSet(List<Mapping> mappings) {
+    public MappingSet(List<Mapping> mappings, Collection<String> includeFields) {
         if (mappings.isEmpty()) {
             this.empty = true;
             this.resolvedSchema = new Mapping(RESOLVED_INDEX_NAME, RESOLVED_MAPPING_NAME, Field.NO_FIELDS);
@@ -78,15 +73,15 @@ public MappingSet(List mappings) {
                 mappingsToSchema.put(typeName, mapping);
             }
 
-            this.resolvedSchema = mergeMappings(mappings);
+            this.resolvedSchema = mergeMappings(mappings, includeFields);
         }
     }
 
-    private static Mapping mergeMappings(List<Mapping> mappings) {
+    private static Mapping mergeMappings(List<Mapping> mappings, Collection<String> includeFields) {
         Map<String, Object[]> fieldMap = new LinkedHashMap<String, Object[]>();
         for (Mapping mapping: mappings) {
             for (Field field : mapping.getFields()) {
-                addToFieldTable(field, "", fieldMap);
+                addToFieldTable(field, "", fieldMap, includeFields);
             }
         }
         Field[] collapsed = collapseFields(fieldMap);
@@ -94,10 +89,13 @@ private static Mapping mergeMappings(List mappings) {
     }
 
     @SuppressWarnings("unchecked")
-    private static void addToFieldTable(Field field, String parent, Map<String, Object[]> fieldTable) {
+    private static void addToFieldTable(Field field, String parent, Map<String, Object[]> fieldTable, Collection<String> includeFields) {
         String fullName = parent + field.name();
         Object[] entry = fieldTable.get(fullName);
-        if (entry == null) {
+        if (!includeFields.isEmpty() && !includeFields.contains(fullName)) {
+            return;
+        }
+        else if (entry == null) {
             // Haven't seen field yet.
             if (FieldType.isCompound(field.type())) {
                 // visit its children
@@ -105,7 +103,7 @@ private static void addToFieldTable(Field field, String parent, Map
                 Map<String, Object[]> subTable = (Map<String, Object[]>)entry[1];
                 String prefix = fullName + ".";
                 for (Field subField : field.properties()) {
-                    addToFieldTable(subField, prefix, subTable);
+                    addToFieldTable(subField, prefix, subTable, includeFields);
                 }
             }
         }
diff --git a/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala b/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala
index 559664144..2f00c2eb4 100644
--- a/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala
+++ b/spark/core/src/main/scala/org/elasticsearch/spark/rdd/AbstractEsRDD.scala
@@ -19,7 +19,6 @@ package org.elasticsearch.spark.rdd;
 
 import JDKCollectionConvertersCompat.Converters._
 
-import scala.reflect.ClassTag
 import org.apache.commons.logging.LogFactory
 import org.apache.spark.Partition
 import org.apache.spark.SparkContext
@@ -31,12 +30,12 @@ import org.elasticsearch.hadoop.rest.PartitionDefinition
 import org.elasticsearch.hadoop.util.ObjectUtils
 import org.elasticsearch.spark.cfg.SparkSettingsManager
 import org.elasticsearch.hadoop.rest.RestRepository
-
-import scala.annotation.meta.param
+import org.elasticsearch.hadoop.serialization.dto.mapping.{Mapping, MappingSet}
 
 private[spark] abstract class AbstractEsRDD[T: ClassTag](
     @(transient @param) sc: SparkContext,
-    val params: scala.collection.Map[String, String] = Map.empty)
+    val params: scala.collection.Map[String, String] = Map.empty,
+    @(transient @param) mapping: Mapping = null)
   extends RDD[T](sc, Nil) {
 
   private val init = { ObjectUtils.loadClass("org.elasticsearch.spark.rdd.CompatUtils", classOf[ObjectUtils].getClassLoader) }
@@ -75,7 +74,7 @@ private[spark] abstract class AbstractEsRDD[T: ClassTag](
   }
 
   @transient private[spark] lazy val esPartitions = {
-    RestService.findPartitions(esCfg, logger)
+    RestService.findPartitions(esCfg, logger, mapping)
   }
 }
diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala
index 86ffbfa17..fb57e422d 100644
--- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala
+++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/DefaultSource.scala
@@ -80,6 +80,7 @@ import org.elasticsearch.hadoop.util.StringUtils
 import org.elasticsearch.hadoop.util.Version
 import org.elasticsearch.spark.cfg.SparkSettingsManager
 import org.elasticsearch.spark.serialization.ScalaValueWriter
+import org.elasticsearch.spark.sql.SchemaUtils.{Schema, discoverMapping}
 import org.elasticsearch.spark.sql.streaming.EsSparkSqlStreamingSink
 import org.elasticsearch.spark.sql.streaming.SparkSqlStreamingConfigs
 import org.elasticsearch.spark.sql.streaming.StructuredStreamingVersionLock
@@ -235,11 +236,11 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @
     conf
   }
 
-  @transient lazy val lazySchema = { SchemaUtils.discoverMapping(cfg) }
+  @transient lazy val lazySchema = SchemaUtils.discoverMapping(cfg, userSchema)
 
   @transient lazy val valueWriter = { new ScalaValueWriter }
 
-  override def schema = userSchema.getOrElse(lazySchema.struct)
+  override def schema: StructType = lazySchema.struct
 
   // TableScan
   def buildScan(): RDD[Row] = buildScan(Array.empty)
diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala
index 7b545f15c..e9791f06b 100644
--- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala
+++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowRDD.scala
@@ -41,7 +41,7 @@ private[spark] class ScalaEsRowRDD(
     @(transient @param) sc: SparkContext,
     params: Map[String, String] = Map.empty,
     schema: SchemaUtils.Schema)
-  extends AbstractEsRDD[Row](sc, params) {
+  extends AbstractEsRDD[Row](sc, params, schema.mapping) {
 
   override def compute(split: Partition, context: TaskContext): ScalaEsRowRDDIterator = {
     new ScalaEsRowRDDIterator(context, split.asInstanceOf[EsPartition].esPartition, schema)
diff --git a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
index 849ff5a78..fc1f5e366 100644
--- a/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
+++ b/spark/sql-30/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
@@ -23,26 +23,7 @@ import java.util.{LinkedHashSet => JHashSet}
 import java.util.{List => JList}
 import java.util.{Map => JMap}
 import java.util.Properties
-
-import scala.collection.JavaConverters.asScalaBufferConverter
-import scala.collection.JavaConverters.propertiesAsScalaMapConverter
-import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.types.ArrayType
-import org.apache.spark.sql.types.BinaryType
-import org.apache.spark.sql.types.BooleanType
-import org.apache.spark.sql.types.ByteType
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.sql.types.DataTypes
-import org.apache.spark.sql.types.DoubleType
-import org.apache.spark.sql.types.FloatType
-import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.sql.types.LongType
-import org.apache.spark.sql.types.NullType
-import org.apache.spark.sql.types.ShortType
-import org.apache.spark.sql.types.StringType
-import org.apache.spark.sql.types.StructField
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.TimestampType
+import org.apache.spark.sql.types._
 import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException
 import org.elasticsearch.hadoop.cfg.InternalConfigurationOptions
 import org.elasticsearch.hadoop.cfg.Settings
@@ -70,12 +51,7 @@ import org.elasticsearch.hadoop.serialization.FieldType.SHORT
 import org.elasticsearch.hadoop.serialization.FieldType.STRING
 import org.elasticsearch.hadoop.serialization.FieldType.TEXT
 import org.elasticsearch.hadoop.serialization.FieldType.WILDCARD
-import org.elasticsearch.hadoop.serialization.dto.mapping.Field
-import org.elasticsearch.hadoop.serialization.dto.mapping.GeoField
-import org.elasticsearch.hadoop.serialization.dto.mapping.GeoPointType
-import org.elasticsearch.hadoop.serialization.dto.mapping.GeoShapeType
-import org.elasticsearch.hadoop.serialization.dto.mapping.Mapping
-import org.elasticsearch.hadoop.serialization.dto.mapping.MappingUtils
+import org.elasticsearch.hadoop.serialization.dto.mapping.{Field, GeoField, GeoPointType, GeoShapeType, Mapping, MappingSet, MappingUtils}
 import org.elasticsearch.hadoop.serialization.field.FieldFilter
 import org.elasticsearch.hadoop.serialization.field.FieldFilter.NumberedInclude
 import org.elasticsearch.hadoop.util.Assert
@@ -85,24 +61,52 @@ import org.elasticsearch.hadoop.util.StringUtils
 import org.elasticsearch.spark.sql.Utils.ROOT_LEVEL_NAME
 import org.elasticsearch.spark.sql.Utils.ROW_INFO_ARRAY_PROPERTY
 import org.elasticsearch.spark.sql.Utils.ROW_INFO_ORDER_PROPERTY
+import scala.annotation.tailrec
 
 private[sql] object SchemaUtils {
   case class Schema(mapping: Mapping, struct: StructType)
 
-  def discoverMapping(cfg: Settings): Schema = {
-    val (mapping, geoInfo) = discoverMappingAndGeoFields(cfg)
+  def discoverMapping(cfg: Settings, userSchema: Option[StructType] = None): Schema = {
+    val includeFields = structToColumnsNames(userSchema)
+    val (mapping, geoInfo) = discoverMappingAndGeoFields(cfg, includeFields)
     val struct = convertToStruct(mapping, geoInfo, cfg)
     Schema(mapping, struct)
   }
 
-  def discoverMappingAndGeoFields(cfg: Settings): (Mapping, JMap[String, GeoField]) = {
+  def structToColumnsNames(struct: Option[StructType]): Seq[String] = {
+    @tailrec
+    def getInnerMostType(dType: DataType): DataType = dType match {
+      case at: ArrayType => getInnerMostType(at.elementType)
+      case t => t
+    }
+
+    @tailrec
+    def flattenFields(remaining: Seq[(String, DataType)], acc: Seq[String]): Seq[String] = remaining match {
+      case Nil => acc
+      case (name, dataType) +: tail =>
+        getInnerMostType(dataType) match {
+          case s: StructType =>
+            val nestedFields = s.fields.map(f => (s"$name.${f.name}", f.dataType))
+            flattenFields(nestedFields ++ tail, acc :+ name)
+          case _ =>
+            flattenFields(tail, name +: acc)
+        }
+    }
+
+    struct match {
+      case None => Seq.empty
+      case Some(s) => flattenFields(s.fields.map(f => (f.name, f.dataType)), Seq.empty)
+    }
+  }
+
+  def discoverMappingAndGeoFields(cfg: Settings, includeFields: Seq[String]): (Mapping, JMap[String, GeoField]) = {
     InitializationUtils.validateSettings(cfg)
     InitializationUtils.discoverClusterInfo(cfg, Utils.LOGGER)
 
     val repo = new RestRepository(cfg)
     try {
       if (repo.resourceExists(true)) {
-        var mappingSet = repo.getMappings
+        val mappingSet = repo.getMappings(includeFields.asJava)
         if (mappingSet == null || mappingSet.isEmpty) {
           throw new EsHadoopIllegalArgumentException(s"Cannot find mapping for ${cfg.getResourceRead} - one is required before using Spark SQL")
         }
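A minimal usage sketch of the behaviour this patch enables, assuming the Spark SQL 3.0 integration shown above; the index name, column names, and connection setting are illustrative only and are not part of the change:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

object UserSchemaReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("es-user-schema-read")
      .master("local[*]")
      .getOrCreate()

    // A user-provided schema. With this change, DefaultSource hands it to
    // SchemaUtils.discoverMapping, which flattens it (structToColumnsNames)
    // into the column names "name", "address", "address.city", "address.zip"
    // and forwards them as includeFields, so only those field mappings are
    // fetched and merged instead of the full index mapping.
    val userSchema = StructType(Seq(
      StructField("name", StringType),
      StructField("address", StructType(Seq(
        StructField("city", StringType),
        StructField("zip", StringType))))))

    val people = spark.read
      .format("org.elasticsearch.spark.sql")
      .option("es.nodes", "localhost:9200") // illustrative connection setting
      .schema(userSchema)
      .load("my-index")                     // hypothetical index name

    people.show()
  }
}
```

Because ScalaEsRowRDD now passes schema.mapping down to AbstractEsRDD, RestService.findPartitions reuses the already-resolved Mapping when computing partitions instead of issuing a second _mapping request for the same job.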