#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import sys
import itertools
from multiprocessing.pool import ThreadPool
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
    overload,
    TYPE_CHECKING,
)

import numpy as np

from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.common import inherit_doc, _py2java, _java2py
from pyspark.ml.evaluation import Evaluator, JavaEvaluator
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasCollectSubModels, HasParallelism, HasSeed
from pyspark.ml.util import (
    DefaultParamsReader,
    DefaultParamsWriter,
    MetaAlgorithmReadWrite,
    MLReadable,
    MLReader,
    MLWritable,
    MLWriter,
    JavaMLReader,
    JavaMLWriter,
)
from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
from pyspark.sql.types import BooleanType
from pyspark.sql.dataframe import DataFrame

if TYPE_CHECKING:
    from pyspark.ml._typing import ParamMap
    from py4j.java_gateway import JavaObject
    from py4j.java_collections import JavaArray

__all__ = [
    "ParamGridBuilder",
    "CrossValidator",
    "CrossValidatorModel",
    "TrainValidationSplit",
    "TrainValidationSplitModel",
]


def _parallelFitTasks(
    est: Estimator,
    train: DataFrame,
    eva: Evaluator,
    validation: DataFrame,
    epm: Sequence["ParamMap"],
    collectSubModel: bool,
) -> List[Callable[[], Tuple[int, float, Transformer]]]:
    """
    Creates a list of callables which can be called from different threads to fit and evaluate
    an estimator in parallel. Each callable returns an `(index, metric, subModel)` tuple.

    Parameters
    ----------
    est : :py:class:`pyspark.ml.base.Estimator`
        The estimator to be fit.
    train : :py:class:`pyspark.sql.DataFrame`
        DataFrame, training data set, used for fitting.
    eva : :py:class:`pyspark.ml.evaluation.Evaluator`
        used to compute `metric`
    validation : :py:class:`pyspark.sql.DataFrame`
        DataFrame, validation data set, used for evaluation.
    epm : :py:class:`collections.abc.Sequence`
        Sequence of ParamMap, params maps to be used during fitting & evaluation.
    collectSubModel : bool
        Whether to collect sub model.

    Returns
    -------
    list of callables
        Each callable returns `(index, metric, subModel)`: an index into `epm`, the
        associated metric value, and the fitted model (None when `collectSubModel`
        is False).
    """
    modelIter = est.fitMultiple(train, epm)

    def singleTask() -> Tuple[int, float, Transformer]:
        # `fitMultiple` yields models in an arbitrary order; the returned index
        # identifies which param map in `epm` the model was fit with.
        index, model = next(modelIter)
        # TODO: duplicate evaluator to take extra params from input
        # Note: Supporting tuning params in evaluator need update method
        # `MetaAlgorithmReadWrite.getAllNestedStages`, make it return
        # all nested stages and evaluators
        metric = eva.evaluate(model.transform(validation, epm[index]))
        return index, metric, model if collectSubModel else None

    # The same closure is scheduled len(epm) times; each invocation pulls the
    # next (index, model) pair from the shared iterator.
    return [singleTask] * len(epm)
class ParamGridBuilder:
    r"""
    Builder for a param grid used in grid search-based model selection.

    .. versionadded:: 1.4.0

    Examples
    --------
    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = ParamGridBuilder() \
    ...     .baseOn({lr.labelCol: 'l'}) \
    ...     .baseOn([lr.predictionCol, 'p']) \
    ...     .addGrid(lr.regParam, [1.0, 2.0]) \
    ...     .addGrid(lr.maxIter, [1, 5]) \
    ...     .build()
    >>> expected = [
    ...     {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'},
    ...     {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
    >>> len(output) == len(expected)
    True
    >>> all([m in expected for m in output])
    True
    """

    def __init__(self) -> None:
        # Accumulates Param -> list of candidate values added via addGrid/baseOn.
        self._param_grid: "ParamMap" = {}

    @since("1.4.0")
    def addGrid(self, param: Param[Any], values: List[Any]) -> "ParamGridBuilder":
        """
        Sets the given parameters in this grid to fixed values.
        param must be an instance of Param associated with an instance of Params
        (such as Estimator or Transformer).
        """
        # Guard clause instead of if/else: reject anything that is not a Param.
        if not isinstance(param, Param):
            raise TypeError("param must be an instance of Param")
        self._param_grid[param] = values
        return self

    @since("1.4.0")
    def baseOn(self, *args: Union["ParamMap", Tuple[Param, Any]]) -> "ParamGridBuilder":
        """
        Sets the given parameters in this grid to fixed values.
        Accepts either a parameter dictionary or a list of (parameter, value) pairs.
        """
        if isinstance(args[0], dict):
            # A dict argument is expanded into (param, value) pairs and recursed.
            self.baseOn(*args[0].items())
        else:
            # Each fixed value becomes a single-element grid for its param.
            for param, value in args:
                self.addGrid(param, [value])
        return self

    @since("1.4.0")
    def build(self) -> List["ParamMap"]:
        """
        Builds and returns all combinations of parameters specified
        by the param grid.
        """
        grid_params = list(self._param_grid.keys())
        grid_values = list(self._param_grid.values())

        def make_param_map(combo: Iterable[Any]) -> "ParamMap":
            # Run each chosen value through its Param's type converter.
            return {p: p.typeConverter(v) for p, v in zip(grid_params, combo)}

        # Cartesian product over the value lists, one ParamMap per combination.
        return [make_param_map(combo) for combo in itertools.product(*grid_values)]
class _ValidatorParams(HasSeed):
    """
    Common params for TrainValidationSplit and CrossValidator.
    """

    estimator: Param[Estimator] = Param(
        Params._dummy(), "estimator", "estimator to be cross-validated"
    )
    estimatorParamMaps: Param[List["ParamMap"]] = Param(
        Params._dummy(), "estimatorParamMaps", "estimator param maps"
    )
    evaluator: Param[Evaluator] = Param(
        Params._dummy(),
        "evaluator",
        "evaluator used to select hyper-parameters that maximize the validator metric",
    )

    @since("2.0.0")
    def getEstimator(self) -> Estimator:
        """
        Gets the value of estimator or its default value.
        """
        return self.getOrDefault(self.estimator)

    @since("2.0.0")
    def getEstimatorParamMaps(self) -> List["ParamMap"]:
        """
        Gets the value of estimatorParamMaps or its default value.
        """
        return self.getOrDefault(self.estimatorParamMaps)

    @since("2.0.0")
    def getEvaluator(self) -> Evaluator:
        """
        Gets the value of evaluator or its default value.
        """
        return self.getOrDefault(self.evaluator)

    @classmethod
    def _from_java_impl(
        cls, java_stage: "JavaObject"
    ) -> Tuple[Estimator, List["ParamMap"], Evaluator]:
        """
        Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
        """
        # Load information from java_stage to the instance.
        estimator: Estimator = JavaParams._from_java(java_stage.getEstimator())
        evaluator: Evaluator = JavaParams._from_java(java_stage.getEvaluator())
        if isinstance(estimator, JavaEstimator):
            epms = [
                estimator._transfer_param_map_from_java(epm)
                for epm in java_stage.getEstimatorParamMaps()
            ]
        elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
            # Meta estimator such as Pipeline, OneVsRest
            epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_from_java(
                estimator, java_stage.getEstimatorParamMaps()
            )
        else:
            raise ValueError("Unsupported estimator used in tuning: " + str(estimator))

        return estimator, epms, evaluator

    def _to_java_impl(self) -> Tuple["JavaObject", "JavaObject", "JavaObject"]:
        """
        Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
        """
        gateway = SparkContext._gateway
        assert gateway is not None and SparkContext._jvm is not None
        cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap

        estimator = self.getEstimator()
        if isinstance(estimator, JavaEstimator):
            java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
            for idx, epm in enumerate(self.getEstimatorParamMaps()):
                java_epms[idx] = estimator._transfer_param_map_to_java(epm)
        elif MetaAlgorithmReadWrite.isMetaEstimator(estimator):
            # Meta estimator such as Pipeline, OneVsRest
            java_epms = _ValidatorSharedReadWrite.meta_estimator_transfer_param_maps_to_java(
                estimator, self.getEstimatorParamMaps()
            )
        else:
            raise ValueError("Unsupported estimator used in tuning: " + str(estimator))

        java_estimator = cast(JavaEstimator, self.getEstimator())._to_java()
        java_evaluator = cast(JavaEvaluator, self.getEvaluator())._to_java()
        return java_estimator, java_epms, java_evaluator


class _ValidatorSharedReadWrite:
    """
    Shared persistence helpers used by the CrossValidator and TrainValidationSplit
    readers/writers (both the pure-Python and the meta-estimator code paths).
    """

    @staticmethod
    def meta_estimator_transfer_param_maps_to_java(
        pyEstimator: Estimator, pyParamMaps: Sequence["ParamMap"]
    ) -> "JavaArray":
        # Map each nested Python stage to its Java twin so params can be resolved
        # by (parent uid, name) against the right Java stage.
        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
        stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages))
        sc = SparkContext._active_spark_context

        assert (
            sc is not None
            and SparkContext._jvm is not None
            and SparkContext._gateway is not None
        )

        paramMapCls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
        javaParamMaps = SparkContext._gateway.new_array(paramMapCls, len(pyParamMaps))

        for idx, pyParamMap in enumerate(pyParamMaps):
            javaParamMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
            for pyParam, pyValue in pyParamMap.items():
                javaParam = None
                for pyStage, javaStage in stagePairs:
                    if pyStage._testOwnParam(pyParam.parent, pyParam.name):
                        javaParam = javaStage.getParam(pyParam.name)
                        break
                if javaParam is None:
                    raise ValueError("Resolve param in estimatorParamMaps failed: " + str(pyParam))
                if isinstance(pyValue, Params) and hasattr(pyValue, "_to_java"):
                    javaValue = cast(JavaParams, pyValue)._to_java()
                else:
                    javaValue = _py2java(sc, pyValue)
                pair = javaParam.w(javaValue)
                javaParamMap.put([pair])
            javaParamMaps[idx] = javaParamMap
        return javaParamMaps

    @staticmethod
    def meta_estimator_transfer_param_maps_from_java(
        pyEstimator: Estimator, javaParamMaps: "JavaArray"
    ) -> List["ParamMap"]:
        pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
        stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages))
        sc = SparkContext._active_spark_context
        assert sc is not None and sc._jvm is not None
        pyParamMaps = []
        for javaParamMap in javaParamMaps:
            pyParamMap = dict()
            for javaPair in javaParamMap.toList():
                javaParam = javaPair.param()
                pyParam = None
                for pyStage, javaStage in stagePairs:
                    if pyStage._testOwnParam(javaParam.parent(), javaParam.name()):
                        pyParam = pyStage.getParam(javaParam.name())
                if pyParam is None:
                    raise ValueError(
                        "Resolve param in estimatorParamMaps failed: "
                        + javaParam.parent()
                        + "."
                        + javaParam.name()
                    )
                javaValue = javaPair.value()
                pyValue: Any
                if sc._jvm.Class.forName(
                    "org.apache.spark.ml.util.DefaultParamsWritable"
                ).isInstance(javaValue):
                    # Param value is itself a writable ML stage: wrap it back into Python.
                    pyValue = JavaParams._from_java(javaValue)
                else:
                    pyValue = _java2py(sc, javaValue)
                pyParamMap[pyParam] = pyValue
            pyParamMaps.append(pyParamMap)
        return pyParamMaps

    @staticmethod
    def is_java_convertible(instance: _ValidatorParams) -> bool:
        # Java persistence works only when every nested stage and the evaluator
        # can be converted to their Java counterparts.
        allNestedStages = MetaAlgorithmReadWrite.getAllNestedStages(instance.getEstimator())
        evaluator_convertible = isinstance(instance.getEvaluator(), JavaParams)
        estimator_convertible = all(map(lambda stage: hasattr(stage, "_to_java"), allNestedStages))
        return estimator_convertible and evaluator_convertible

    @staticmethod
    def saveImpl(
        path: str,
        instance: _ValidatorParams,
        sc: SparkContext,
        extraMetadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Param values that are themselves ML stages cannot be JSON-encoded;
        # they are saved to a sub-directory and referenced by relative path.
        numParamsNotJson = 0
        jsonEstimatorParamMaps = []
        for paramMap in instance.getEstimatorParamMaps():
            jsonParamMap = []
            for p, v in paramMap.items():
                jsonParam: Dict[str, Any] = {"parent": p.parent, "name": p.name}
                if (
                    (isinstance(v, Estimator) and not MetaAlgorithmReadWrite.isMetaEstimator(v))
                    or isinstance(v, Transformer)
                    or isinstance(v, Evaluator)
                ):
                    relative_path = f"epm_{p.name}{numParamsNotJson}"
                    param_path = os.path.join(path, relative_path)
                    numParamsNotJson += 1
                    cast(MLWritable, v).save(param_path)
                    jsonParam["value"] = relative_path
                    jsonParam["isJson"] = False
                elif isinstance(v, MLWritable):
                    raise RuntimeError(
                        "ValidatorSharedReadWrite.saveImpl does not handle parameters of type: "
                        "MLWritable that are not Estimator/Evaluator/Transformer, and if parameter "
                        "is estimator, it cannot be meta estimator such as Validator or OneVsRest"
                    )
                else:
                    jsonParam["value"] = v
                    jsonParam["isJson"] = True
                jsonParamMap.append(jsonParam)
            jsonEstimatorParamMaps.append(jsonParamMap)

        skipParams = ["estimator", "evaluator", "estimatorParamMaps"]
        jsonParams = DefaultParamsWriter.extractJsonParams(instance, skipParams)
        jsonParams["estimatorParamMaps"] = jsonEstimatorParamMaps

        DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, jsonParams)
        evaluatorPath = os.path.join(path, "evaluator")
        cast(MLWritable, instance.getEvaluator()).save(evaluatorPath)
        estimatorPath = os.path.join(path, "estimator")
        cast(MLWritable, instance.getEstimator()).save(estimatorPath)

    @staticmethod
    def load(
        path: str, sc: SparkContext, metadata: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
        evaluatorPath = os.path.join(path, "evaluator")
        evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
        estimatorPath = os.path.join(path, "estimator")
        estimator: Estimator = DefaultParamsReader.loadParamsInstance(estimatorPath, sc)

        # Params in the saved maps are identified by (parent uid, name); build a
        # uid -> stage lookup covering all nested stages plus the evaluator.
        uidToParams = MetaAlgorithmReadWrite.getUidMap(estimator)
        uidToParams[evaluator.uid] = evaluator

        jsonEstimatorParamMaps = metadata["paramMap"]["estimatorParamMaps"]

        estimatorParamMaps = []
        for jsonParamMap in jsonEstimatorParamMaps:
            paramMap = {}
            for jsonParam in jsonParamMap:
                est = uidToParams[jsonParam["parent"]]
                param = getattr(est, jsonParam["name"])
                # Older saves may lack the "isJson" flag; treat missing as JSON.
                if "isJson" not in jsonParam or ("isJson" in jsonParam and jsonParam["isJson"]):
                    value = jsonParam["value"]
                else:
                    relativePath = jsonParam["value"]
                    valueSavedPath = os.path.join(path, relativePath)
                    value = DefaultParamsReader.loadParamsInstance(valueSavedPath, sc)
                paramMap[param] = value
            estimatorParamMaps.append(paramMap)

        return metadata, estimator, evaluator, estimatorParamMaps

    @staticmethod
    def validateParams(instance: _ValidatorParams) -> None:
        # (fixed local-name misspelling: was `estiamtor`)
        estimator = instance.getEstimator()
        evaluator = instance.getEvaluator()
        uidMap = MetaAlgorithmReadWrite.getUidMap(estimator)

        for elem in [evaluator] + list(uidMap.values()):  # type: ignore[arg-type]
            if not isinstance(elem, MLWritable):
                raise ValueError(
                    f"Validator write will fail because it contains {elem.uid} "
                    f"which is not writable."
                )

        estimatorParamMaps = instance.getEstimatorParamMaps()
        paramErr = (
            "Validator save requires all Params in estimatorParamMaps to apply to "
            "its Estimator, An extraneous Param was found: "
        )
        for paramMap in estimatorParamMaps:
            for param in paramMap:
                if param.parent not in uidMap:
                    raise ValueError(paramErr + repr(param))

    @staticmethod
    def getValidatorModelWriterPersistSubModelsParam(writer: MLWriter) -> bool:
        if "persistsubmodels" in writer.optionMap:
            persistSubModelsParam = writer.optionMap["persistsubmodels"].lower()
            if persistSubModelsParam == "true":
                return True
            elif persistSubModelsParam == "false":
                return False
            else:
                raise ValueError(
                    f"persistSubModels option value {persistSubModelsParam} is invalid, "
                    f"the possible values are True, 'True' or False, 'False'"
                )
        else:
            # Default: persist sub-models exactly when they were collected.
            return writer.instance.subModels is not None  # type: ignore[attr-defined]


_save_with_persist_submodels_no_submodels_found_err: str = (
    "When persisting tuning models, you can only set persistSubModels to true if the tuning "
    "was done with collectSubModels set to true. To save the sub-models, try rerunning fitting "
    "with collectSubModels set to true."
)


@inherit_doc
class CrossValidatorReader(MLReader["CrossValidator"]):
    def __init__(self, cls: Type["CrossValidator"]):
        super(CrossValidatorReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> "CrossValidator":
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            # Saved by the JVM side: delegate to the Java reader.
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
        else:
            metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
                path, self.sc, metadata
            )
            cv = CrossValidator(
                estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
            )
            cv = cv._resetUid(metadata["uid"])
            DefaultParamsReader.getAndSetParams(cv, metadata, skipParams=["estimatorParamMaps"])
            return cv


@inherit_doc
class CrossValidatorWriter(MLWriter):
    def __init__(self, instance: "CrossValidator"):
        super(CrossValidatorWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        _ValidatorSharedReadWrite.validateParams(self.instance)
        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)


@inherit_doc
class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
    def __init__(self, cls: Type["CrossValidatorModel"]):
        super(CrossValidatorModelReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> "CrossValidatorModel":
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
        else:
            metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
                path, self.sc, metadata
            )
            numFolds = metadata["paramMap"]["numFolds"]
            bestModelPath = os.path.join(path, "bestModel")
            bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
            avgMetrics = metadata["avgMetrics"]
            if "stdMetrics" in metadata:
                stdMetrics = metadata["stdMetrics"]
            else:
                stdMetrics = None
            persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"]

            if persistSubModels:
                # BUG FIX: was `[[None] * len(estimatorParamMaps)] * numFolds`,
                # which aliases ONE inner list across all folds, so every fold
                # ended up holding the last fold's sub-models. Build an
                # independent row per fold instead.
                subModels = [[None] * len(estimatorParamMaps) for _ in range(numFolds)]
                for splitIndex in range(numFolds):
                    for paramIndex in range(len(estimatorParamMaps)):
                        modelPath = os.path.join(
                            path, "subModels", f"fold{splitIndex}", f"{paramIndex}"
                        )
                        subModels[splitIndex][paramIndex] = DefaultParamsReader.loadParamsInstance(
                            modelPath, self.sc
                        )
            else:
                subModels = None

            cvModel = CrossValidatorModel(
                bestModel,
                avgMetrics=avgMetrics,
                subModels=cast(List[List[Model]], subModels),
                stdMetrics=stdMetrics,
            )
            cvModel = cvModel._resetUid(metadata["uid"])
            cvModel.set(cvModel.estimator, estimator)
            cvModel.set(cvModel.estimatorParamMaps, estimatorParamMaps)
            cvModel.set(cvModel.evaluator, evaluator)
            DefaultParamsReader.getAndSetParams(
                cvModel, metadata, skipParams=["estimatorParamMaps"]
            )
            return cvModel


@inherit_doc
class CrossValidatorModelWriter(MLWriter):
    def __init__(self, instance: "CrossValidatorModel"):
        super(CrossValidatorModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        _ValidatorSharedReadWrite.validateParams(self.instance)
        instance = self.instance
        persistSubModels = _ValidatorSharedReadWrite.getValidatorModelWriterPersistSubModelsParam(
            self
        )
        extraMetadata = {"avgMetrics": instance.avgMetrics, "persistSubModels": persistSubModels}
        if instance.stdMetrics:
            extraMetadata["stdMetrics"] = instance.stdMetrics

        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
        bestModelPath = os.path.join(path, "bestModel")
        cast(MLWritable, instance.bestModel).save(bestModelPath)
        if persistSubModels:
            if instance.subModels is None:
                raise ValueError(_save_with_persist_submodels_no_submodels_found_err)
            subModelsPath = os.path.join(path, "subModels")
            for splitIndex in range(instance.getNumFolds()):
                splitPath = os.path.join(subModelsPath, f"fold{splitIndex}")
                for paramIndex in range(len(instance.getEstimatorParamMaps())):
                    modelPath = os.path.join(splitPath, f"{paramIndex}")
                    cast(MLWritable, instance.subModels[splitIndex][paramIndex]).save(modelPath)


class _CrossValidatorParams(_ValidatorParams):
    """
    Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.

    .. versionadded:: 3.0.0
    """

    numFolds: Param[int] = Param(
        Params._dummy(),
        "numFolds",
        "number of folds for cross validation",
        typeConverter=TypeConverters.toInt,
    )

    foldCol: Param[str] = Param(
        Params._dummy(),
        "foldCol",
        "Param for the column name of user "
        + "specified fold number. Once this is specified, :py:class:`CrossValidator` "
        + "won't do random k-fold split. Note that this column should be integer type "
        + "with range [0, numFolds) and Spark will throw exception on out-of-range "
        + "fold numbers.",
        typeConverter=TypeConverters.toString,
    )

    def __init__(self, *args: Any):
        super(_CrossValidatorParams, self).__init__(*args)
        self._setDefault(numFolds=3, foldCol="")

    @since("1.4.0")
    def getNumFolds(self) -> int:
        """
        Gets the value of numFolds or its default value.
        """
        return self.getOrDefault(self.numFolds)

    @since("3.1.0")
    def getFoldCol(self) -> str:
        """
        Gets the value of foldCol or its default value.
        """
        return self.getOrDefault(self.foldCol)
class CrossValidator(
    Estimator["CrossValidatorModel"],
    _CrossValidatorParams,
    HasParallelism,
    HasCollectSubModels,
    MLReadable["CrossValidator"],
    MLWritable,
):
    """
    K-fold cross validation performs model selection by splitting the dataset into a set of
    non-overlapping randomly partitioned folds which are used as separate training and test
    datasets e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test)
    dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each
    fold is used as the test set exactly once.

    .. versionadded:: 1.4.0

    Examples
    --------
    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel
    >>> import tempfile
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"])
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=2)
    >>> cvModel = cv.fit(dataset)
    >>> cvModel.getNumFolds()
    3
    >>> cvModel.avgMetrics[0]
    0.5
    >>> path = tempfile.mkdtemp()
    >>> model_path = path + "/model"
    >>> cvModel.write().save(model_path)
    >>> cvModelRead = CrossValidatorModel.read().load(model_path)
    >>> cvModelRead.avgMetrics
    [0.5, ...
    >>> evaluator.evaluate(cvModel.transform(dataset))
    0.8333...
    >>> evaluator.evaluate(cvModelRead.transform(dataset))
    0.8333...
    """

    # Populated by the @keyword_only decorator with the kwargs of the call.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        estimator: Optional[Estimator] = None,
        estimatorParamMaps: Optional[List["ParamMap"]] = None,
        evaluator: Optional[Evaluator] = None,
        numFolds: int = 3,
        seed: Optional[int] = None,
        parallelism: int = 1,
        collectSubModels: bool = False,
        foldCol: str = "",
    ) -> None:
        """
        __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
                 seed=None, parallelism=1, collectSubModels=False, foldCol="")
        """
        super(CrossValidator, self).__init__()
        self._setDefault(parallelism=1)
        kwargs = self._input_kwargs
        self._set(**kwargs)
    @since("2.0.0")
    def setEstimator(self, value: Estimator) -> "CrossValidator":
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)
    @since("2.0.0")
    def setEstimatorParamMaps(self, value: List["ParamMap"]) -> "CrossValidator":
        """
        Sets the value of :py:attr:`estimatorParamMaps`.
        """
        return self._set(estimatorParamMaps=value)
    @since("2.0.0")
    def setEvaluator(self, value: Evaluator) -> "CrossValidator":
        """
        Sets the value of :py:attr:`evaluator`.
        """
        return self._set(evaluator=value)
    @since("1.4.0")
    def setNumFolds(self, value: int) -> "CrossValidator":
        """
        Sets the value of :py:attr:`numFolds`.
        """
        return self._set(numFolds=value)
    @since("3.1.0")
    def setFoldCol(self, value: str) -> "CrossValidator":
        """
        Sets the value of :py:attr:`foldCol`.
        """
        return self._set(foldCol=value)
    def setSeed(self, value: int) -> "CrossValidator":
        """
        Sets the value of :py:attr:`seed`.
        """
        return self._set(seed=value)
    def setParallelism(self, value: int) -> "CrossValidator":
        """
        Sets the value of :py:attr:`parallelism`.
        """
        return self._set(parallelism=value)
    def setCollectSubModels(self, value: bool) -> "CrossValidator":
        """
        Sets the value of :py:attr:`collectSubModels`.
        """
        return self._set(collectSubModels=value)
@staticmethoddef_gen_avg_and_std_metrics(metrics_all:List[List[float]])->Tuple[List[float],List[float]]:avg_metrics=np.mean(metrics_all,axis=0)std_metrics=np.std(metrics_all,axis=0)returnlist(avg_metrics),list(std_metrics)def_fit(self,dataset:DataFrame)->"CrossValidatorModel":est=self.getOrDefault(self.estimator)epm=self.getOrDefault(self.estimatorParamMaps)numModels=len(epm)eva=self.getOrDefault(self.evaluator)nFolds=self.getOrDefault(self.numFolds)metrics_all=[[0.0]*numModelsforiinrange(nFolds)]pool=ThreadPool(processes=min(self.getParallelism(),numModels))subModels=NonecollectSubModelsParam=self.getCollectSubModels()ifcollectSubModelsParam:subModels=[[Noneforjinrange(numModels)]foriinrange(nFolds)]datasets=self._kFold(dataset)foriinrange(nFolds):validation=datasets[i][1].cache()train=datasets[i][0].cache()tasks=map(inheritable_thread_target,_parallelFitTasks(est,train,eva,validation,epm,collectSubModelsParam),)forj,metric,subModelinpool.imap_unordered(lambdaf:f(),tasks):metrics_all[i][j]=metricifcollectSubModelsParam:assertsubModelsisnotNonesubModels[i][j]=subModelvalidation.unpersist()train.unpersist()metrics,std_metrics=CrossValidator._gen_avg_and_std_metrics(metrics_all)ifeva.isLargerBetter():bestIndex=np.argmax(metrics)else:bestIndex=np.argmin(metrics)bestModel=est.fit(dataset,epm[bestIndex])returnself._copyValues(CrossValidatorModel(bestModel,metrics,cast(List[List[Model]],subModels),std_metrics))def_kFold(self,dataset:DataFrame)->List[Tuple[DataFrame,DataFrame]]:nFolds=self.getOrDefault(self.numFolds)foldCol=self.getOrDefault(self.foldCol)datasets=[]ifnotfoldCol:# Do random k-fold split.seed=self.getOrDefault(self.seed)h=1.0/nFoldsrandCol=self.uid+"_rand"df=dataset.select("*",rand(seed).alias(randCol))foriinrange(nFolds):validateLB=i*hvalidateUB=(i+1)*hcondition=(df[randCol]>=validateLB)&(df[randCol]<validateUB)validation=df.filter(condition)train=df.filter(~condition)datasets.append((train,validation))else:# Use user-specified fold 
numbers.defchecker(foldNum:int)->bool:iffoldNum<0orfoldNum>=nFolds:raiseValueError("Fold number must be in range [0, %s), but got %s."%(nFolds,foldNum))returnTruechecker_udf=UserDefinedFunction(checker,BooleanType())foriinrange(nFolds):training=dataset.filter(checker_udf(dataset[foldCol])&(col(foldCol)!=lit(i)))validation=dataset.filter(checker_udf(dataset[foldCol])&(col(foldCol)==lit(i)))iftraining.rdd.getNumPartitions()==0orlen(training.take(1))==0:raiseValueError("The training data at fold %s is empty."%i)ifvalidation.rdd.getNumPartitions()==0orlen(validation.take(1))==0:raiseValueError("The validation data at fold %s is empty."%i)datasets.append((training,validation))returndatasets
    def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidator":
        """
        Creates a copy of this instance with a randomly generated uid
        and some extra params. This copy creates a deep copy of
        the embedded paramMap, and copies the embedded and extra parameters over.

        .. versionadded:: 1.4.0

        Parameters
        ----------
        extra : dict, optional
            Extra parameters to copy to the new instance

        Returns
        -------
        :py:class:`CrossValidator`
            Copy of this instance
        """
        if extra is None:
            extra = dict()
        newCV = Params.copy(self, extra)
        # Only copy estimator/evaluator when they were explicitly set.
        if self.isSet(self.estimator):
            newCV.setEstimator(self.getEstimator().copy(extra))
        # estimatorParamMaps remain the same
        if self.isSet(self.evaluator):
            newCV.setEvaluator(self.getEvaluator().copy(extra))
        return newCV
    @since("2.3.0")
    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        # Prefer the Java writer when the whole estimator/evaluator tree is
        # convertible to its JVM counterparts.
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)  # type: ignore[arg-type]
        return CrossValidatorWriter(self)
    @classmethod
    @since("2.3.0")
    def read(cls) -> CrossValidatorReader:
        """Returns an MLReader instance for this class."""
        return CrossValidatorReader(cls)
    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "CrossValidator":
        """
        Given a Java CrossValidator, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
        numFolds = java_stage.getNumFolds()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()
        foldCol = java_stage.getFoldCol()
        # Create a new instance of this stage.
        py_stage = cls(
            estimator=estimator,
            estimatorParamMaps=epms,
            evaluator=evaluator,
            numFolds=numFolds,
            seed=seed,
            parallelism=parallelism,
            collectSubModels=collectSubModels,
            foldCol=foldCol,
        )
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java CrossValidator. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """
        estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setSeed(self.getSeed())
        _java_obj.setNumFolds(self.getNumFolds())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        _java_obj.setFoldCol(self.getFoldCol())

        return _java_obj
class CrossValidatorModel(
    Model, _CrossValidatorParams, MLReadable["CrossValidatorModel"], MLWritable
):
    """
    CrossValidatorModel contains the model with the highest average cross-validation
    metric across folds and uses this model to transform input data. CrossValidatorModel
    also tracks the metrics for each param map evaluated.

    .. versionadded:: 1.4.0

    Notes
    -----
    Since version 3.3.0, CrossValidatorModel contains a new attribute "stdMetrics",
    which represent standard deviation of metrics for each paramMap in
    CrossValidator.estimatorParamMaps.
    """

    def __init__(
        self,
        bestModel: Model,
        avgMetrics: Optional[List[float]] = None,
        subModels: Optional[List[List[Model]]] = None,
        stdMetrics: Optional[List[float]] = None,
    ):
        super(CrossValidatorModel, self).__init__()
        #: best model from cross validation
        self.bestModel = bestModel
        #: Average cross-validation metrics for each paramMap in
        #: CrossValidator.estimatorParamMaps, in the corresponding order.
        self.avgMetrics = avgMetrics or []
        #: sub model list from cross validation
        self.subModels = subModels
        #: standard deviation of metrics for each paramMap in
        #: CrossValidator.estimatorParamMaps, in the corresponding order.
        self.stdMetrics = stdMetrics or []

    def _transform(self, dataset: DataFrame) -> DataFrame:
        # Delegate transformation to the winning model.
        return self.bestModel.transform(dataset)
[docs]defcopy(self,extra:Optional["ParamMap"]=None)->"CrossValidatorModel":""" Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. It does not copy the extra Params into the subModels. .. versionadded:: 1.4.0 Parameters ---------- extra : dict, optional Extra parameters to copy to the new instance Returns ------- :py:class:`CrossValidatorModel` Copy of this instance """ifextraisNone:extra=dict()bestModel=self.bestModel.copy(extra)avgMetrics=list(self.avgMetrics)assertself.subModelsisnotNonesubModels=[[sub_model.copy()forsub_modelinfold_sub_models]forfold_sub_modelsinself.subModels]stdMetrics=list(self.stdMetrics)returnself._copyValues(CrossValidatorModel(bestModel,avgMetrics,subModels,stdMetrics),extra=extra)
    @since("2.3.0")
    def write(self) -> MLWriter:
        """Returns an MLWriter instance for this ML instance."""
        # Prefer the Java writer when the whole model tree is JVM-convertible.
        if _ValidatorSharedReadWrite.is_java_convertible(self):
            return JavaMLWriter(self)  # type: ignore[arg-type]
        return CrossValidatorModelWriter(self)
    @classmethod
    @since("2.3.0")
    def read(cls) -> CrossValidatorModelReader:
        """Returns an MLReader instance for this class."""
        return CrossValidatorModelReader(cls)
    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "CrossValidatorModel":
        """
        Given a Java CrossValidatorModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        sc = SparkContext._active_spark_context
        assert sc is not None

        bestModel: Model = JavaParams._from_java(java_stage.bestModel())
        avgMetrics = _java2py(sc, java_stage.avgMetrics())
        estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)

        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)
        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "numFolds": java_stage.getNumFolds(),
            "foldCol": java_stage.getFoldCol(),
            "seed": java_stage.getSeed(),
        }
        for param_name, param_val in params.items():
            py_stage = py_stage._set(**{param_name: param_val})

        if java_stage.hasSubModels():
            py_stage.subModels = [
                [JavaParams._from_java(sub_model) for sub_model in fold_sub_models]
                for fold_sub_models in java_stage.subModels()
            ]

        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """
        sc = SparkContext._active_spark_context
        assert sc is not None

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.CrossValidatorModel",
            self.uid,
            cast(JavaParams, self.bestModel)._to_java(),
            _py2java(sc, self.avgMetrics),
        )
        estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "numFolds": self.getNumFolds(),
            "foldCol": self.getFoldCol(),
            "seed": self.getSeed(),
        }
        for param_name, param_val in params.items():
            java_param = _java_obj.getParam(param_name)
            pair = java_param.w(param_val)
            _java_obj.set(pair)

        if self.subModels is not None:
            java_sub_models = [
                [cast(JavaParams, sub_model)._to_java() for sub_model in fold_sub_models]
                for fold_sub_models in self.subModels
            ]
            _java_obj.setSubModels(java_sub_models)
        return _java_obj
@inherit_doc
class TrainValidationSplitReader(MLReader["TrainValidationSplit"]):
    # Reader that restores a TrainValidationSplit estimator from disk. Delegates
    # to the JVM reader when the metadata was written by the Scala implementation.

    def __init__(self, cls: Type["TrainValidationSplit"]):
        super(TrainValidationSplitReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> "TrainValidationSplit":
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            # Written by the Scala/Java side; use the Java reader.
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
        else:
            metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
                path, self.sc, metadata
            )
            tvs = TrainValidationSplit(
                estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
            )
            tvs = tvs._resetUid(metadata["uid"])
            # estimatorParamMaps were restored above, so skip them here.
            DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams=["estimatorParamMaps"])
            return tvs


@inherit_doc
class TrainValidationSplitWriter(MLWriter):
    # Writer for a pure-Python TrainValidationSplit estimator.

    def __init__(self, instance: "TrainValidationSplit"):
        super(TrainValidationSplitWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        _ValidatorSharedReadWrite.validateParams(self.instance)
        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)


@inherit_doc
class TrainValidationSplitModelReader(MLReader["TrainValidationSplitModel"]):
    # Reader that restores a fitted TrainValidationSplitModel, including the best
    # model and (optionally) the per-param-map sub-models.

    def __init__(self, cls: Type["TrainValidationSplitModel"]):
        super(TrainValidationSplitModelReader, self).__init__()
        self.cls = cls

    def load(self, path: str) -> "TrainValidationSplitModel":
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        if not DefaultParamsReader.isPythonParamsInstance(metadata):
            # Written by the Scala/Java side; use the Java reader.
            return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
        else:
            metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
                path, self.sc, metadata
            )
            bestModelPath = os.path.join(path, "bestModel")
            bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
            validationMetrics = metadata["validationMetrics"]
            # Older saves may lack the flag; treat a missing key as False.
            persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"]

            if persistSubModels:
                # One sub-model per param map, stored under subModels/<index>.
                subModels = [None] * len(estimatorParamMaps)
                for paramIndex in range(len(estimatorParamMaps)):
                    modelPath = os.path.join(path, "subModels", f"{paramIndex}")
                    subModels[paramIndex] = DefaultParamsReader.loadParamsInstance(
                        modelPath, self.sc
                    )
            else:
                subModels = None

            tvsModel = TrainValidationSplitModel(
                bestModel,
                validationMetrics=validationMetrics,
                subModels=cast(Optional[List[Model]], subModels),
            )
            tvsModel = tvsModel._resetUid(metadata["uid"])
            tvsModel.set(tvsModel.estimator, estimator)
            tvsModel.set(tvsModel.estimatorParamMaps, estimatorParamMaps)
            tvsModel.set(tvsModel.evaluator, evaluator)
            DefaultParamsReader.getAndSetParams(
                tvsModel, metadata, skipParams=["estimatorParamMaps"]
            )
            return tvsModel


@inherit_doc
class TrainValidationSplitModelWriter(MLWriter):
    # Writer for a pure-Python TrainValidationSplitModel; persists the best model
    # and, when enabled, each sub-model under subModels/<paramMapIndex>.

    def __init__(self, instance: "TrainValidationSplitModel"):
        super(TrainValidationSplitModelWriter, self).__init__()
        self.instance = instance

    def saveImpl(self, path: str) -> None:
        _ValidatorSharedReadWrite.validateParams(self.instance)
        instance = self.instance
        persistSubModels = _ValidatorSharedReadWrite.getValidatorModelWriterPersistSubModelsParam(
            self
        )
        extraMetadata = {
            "validationMetrics": instance.validationMetrics,
            "persistSubModels": persistSubModels,
        }
        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
        bestModelPath = os.path.join(path, "bestModel")
        cast(MLWritable, instance.bestModel).save(bestModelPath)
        if persistSubModels:
            # Persisting sub-models requires them to have been collected at fit time.
            if instance.subModels is None:
                raise ValueError(_save_with_persist_submodels_no_submodels_found_err)
            subModelsPath = os.path.join(path, "subModels")
            for paramIndex in range(len(instance.getEstimatorParamMaps())):
                modelPath = os.path.join(subModelsPath, f"{paramIndex}")
                cast(MLWritable, instance.subModels[paramIndex]).save(modelPath)


class _TrainValidationSplitParams(_ValidatorParams):
    """
    Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.

    .. versionadded:: 3.0.0
    """

    # Fraction of the data used for training; the remainder is the validation set.
    trainRatio: Param[float] = Param(
        Params._dummy(),
        "trainRatio",
        "Param for ratio between train and\
 validation data. Must be between 0 and 1.",
        typeConverter=TypeConverters.toFloat,
    )

    def __init__(self, *args: Any):
        super(_TrainValidationSplitParams, self).__init__(*args)
        self._setDefault(trainRatio=0.75)

    @since("2.0.0")
    def getTrainRatio(self) -> float:
        """
        Gets the value of trainRatio or its default value.
        """
        return self.getOrDefault(self.trainRatio)
class TrainValidationSplit(
    Estimator["TrainValidationSplitModel"],
    _TrainValidationSplitParams,
    HasParallelism,
    HasCollectSubModels,
    MLReadable["TrainValidationSplit"],
    MLWritable,
):
    """
    Validation for hyper-parameter tuning. Randomly splits the input dataset into train and
    validation sets, and uses evaluation metric on the validation set to select the best model.
    Similar to :class:`CrossValidator`, but only splits the set once.

    .. versionadded:: 2.0.0

    Examples
    --------
    >>> from pyspark.ml.classification import LogisticRegression
    >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
    >>> from pyspark.ml.linalg import Vectors
    >>> from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder
    >>> from pyspark.ml.tuning import TrainValidationSplitModel
    >>> import tempfile
    >>> dataset = spark.createDataFrame(
    ...     [(Vectors.dense([0.0]), 0.0),
    ...      (Vectors.dense([0.4]), 1.0),
    ...      (Vectors.dense([0.5]), 0.0),
    ...      (Vectors.dense([0.6]), 1.0),
    ...      (Vectors.dense([1.0]), 1.0)] * 10,
    ...     ["features", "label"]).repartition(1)
    >>> lr = LogisticRegression()
    >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    >>> evaluator = BinaryClassificationEvaluator()
    >>> tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,
    ...     parallelism=1, seed=42)
    >>> tvsModel = tvs.fit(dataset)
    >>> tvsModel.getTrainRatio()
    0.75
    >>> tvsModel.validationMetrics
    [0.5, ...
    >>> path = tempfile.mkdtemp()
    >>> model_path = path + "/model"
    >>> tvsModel.write().save(model_path)
    >>> tvsModelRead = TrainValidationSplitModel.read().load(model_path)
    >>> tvsModelRead.validationMetrics
    [0.5, ...
    >>> evaluator.evaluate(tvsModel.transform(dataset))
    0.833...
    >>> evaluator.evaluate(tvsModelRead.transform(dataset))
    0.833...
    """

    # Populated by the @keyword_only decorator with the caller's keyword args.
    _input_kwargs: Dict[str, Any]

    @keyword_only
    def __init__(
        self,
        *,
        estimator: Optional[Estimator] = None,
        estimatorParamMaps: Optional[List["ParamMap"]] = None,
        evaluator: Optional[Evaluator] = None,
        trainRatio: float = 0.75,
        parallelism: int = 1,
        collectSubModels: bool = False,
        seed: Optional[int] = None,
    ) -> None:
        """
        __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \
                 trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)
        """
        super(TrainValidationSplit, self).__init__()
        self._setDefault(parallelism=1)
        # Forward only the keyword args the caller actually supplied.
        kwargs = self._input_kwargs
        self._set(**kwargs)
[docs]@since("2.0.0")@keyword_onlydefsetParams(self,*,estimator:Optional[Estimator]=None,estimatorParamMaps:Optional[List["ParamMap"]]=None,evaluator:Optional[Evaluator]=None,trainRatio:float=0.75,parallelism:int=1,collectSubModels:bool=False,seed:Optional[int]=None,)->"TrainValidationSplit":""" setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \ trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None): Sets params for the train validation split. """kwargs=self._input_kwargsreturnself._set(**kwargs)
[docs]@since("2.0.0")defsetEstimator(self,value:Estimator)->"TrainValidationSplit":""" Sets the value of :py:attr:`estimator`. """returnself._set(estimator=value)
[docs]@since("2.0.0")defsetEstimatorParamMaps(self,value:List["ParamMap"])->"TrainValidationSplit":""" Sets the value of :py:attr:`estimatorParamMaps`. """returnself._set(estimatorParamMaps=value)
[docs]@since("2.0.0")defsetEvaluator(self,value:Evaluator)->"TrainValidationSplit":""" Sets the value of :py:attr:`evaluator`. """returnself._set(evaluator=value)
[docs]@since("2.0.0")defsetTrainRatio(self,value:float)->"TrainValidationSplit":""" Sets the value of :py:attr:`trainRatio`. """returnself._set(trainRatio=value)
[docs]defsetSeed(self,value:int)->"TrainValidationSplit":""" Sets the value of :py:attr:`seed`. """returnself._set(seed=value)
[docs]defsetParallelism(self,value:int)->"TrainValidationSplit":""" Sets the value of :py:attr:`parallelism`. """returnself._set(parallelism=value)
[docs]defsetCollectSubModels(self,value:bool)->"TrainValidationSplit":""" Sets the value of :py:attr:`collectSubModels`. """returnself._set(collectSubModels=value)
[docs]defcopy(self,extra:Optional["ParamMap"]=None)->"TrainValidationSplit":""" Creates a copy of this instance with a randomly generated uid and some extra params. This copies creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. .. versionadded:: 2.0.0 Parameters ---------- extra : dict, optional Extra parameters to copy to the new instance Returns ------- :py:class:`TrainValidationSplit` Copy of this instance """ifextraisNone:extra=dict()newTVS=Params.copy(self,extra)ifself.isSet(self.estimator):newTVS.setEstimator(self.getEstimator().copy(extra))# estimatorParamMaps remain the sameifself.isSet(self.evaluator):newTVS.setEvaluator(self.getEvaluator().copy(extra))returnnewTVS
[docs]@since("2.3.0")defwrite(self)->MLWriter:"""Returns an MLWriter instance for this ML instance."""if_ValidatorSharedReadWrite.is_java_convertible(self):returnJavaMLWriter(self)# type: ignore[arg-type]returnTrainValidationSplitWriter(self)
[docs]@classmethod@since("2.3.0")defread(cls)->TrainValidationSplitReader:"""Returns an MLReader instance for this class."""returnTrainValidationSplitReader(cls)
    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplit":
        """
        Given a Java TrainValidationSplit, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
        trainRatio = java_stage.getTrainRatio()
        seed = java_stage.getSeed()
        parallelism = java_stage.getParallelism()
        collectSubModels = java_stage.getCollectSubModels()
        # Create a new instance of this stage.
        py_stage = cls(
            estimator=estimator,
            estimatorParamMaps=epms,
            evaluator=evaluator,
            trainRatio=trainRatio,
            seed=seed,
            parallelism=parallelism,
            collectSubModels=collectSubModels,
        )
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """
        estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()

        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.TrainValidationSplit", self.uid
        )
        # Mirror every Python-side param onto the Java stage.
        _java_obj.setEstimatorParamMaps(epms)
        _java_obj.setEvaluator(evaluator)
        _java_obj.setEstimator(estimator)
        _java_obj.setTrainRatio(self.getTrainRatio())
        _java_obj.setSeed(self.getSeed())
        _java_obj.setParallelism(self.getParallelism())
        _java_obj.setCollectSubModels(self.getCollectSubModels())
        return _java_obj
class TrainValidationSplitModel(
    Model, _TrainValidationSplitParams, MLReadable["TrainValidationSplitModel"], MLWritable
):
    """
    Model from train validation split.

    .. versionadded:: 2.0.0
    """

    def __init__(
        self,
        bestModel: Model,
        validationMetrics: Optional[List[float]] = None,
        subModels: Optional[List[Model]] = None,
    ):
        super(TrainValidationSplitModel, self).__init__()
        #: best model from train validation split
        self.bestModel = bestModel
        #: evaluated validation metrics (NOTE: `or []` means an explicitly passed
        #: empty list and None are treated identically)
        self.validationMetrics = validationMetrics or []
        #: sub models from train validation split; None unless collected at fit time
        self.subModels = subModels

    def _transform(self, dataset: DataFrame) -> DataFrame:
        # Transformation simply delegates to the selected best model.
        return self.bestModel.transform(dataset)
[docs]defcopy(self,extra:Optional["ParamMap"]=None)->"TrainValidationSplitModel":""" Creates a copy of this instance with a randomly generated uid and some extra params. This copies the underlying bestModel, creates a deep copy of the embedded paramMap, and copies the embedded and extra parameters over. And, this creates a shallow copy of the validationMetrics. It does not copy the extra Params into the subModels. .. versionadded:: 2.0.0 Parameters ---------- extra : dict, optional Extra parameters to copy to the new instance Returns ------- :py:class:`TrainValidationSplitModel` Copy of this instance """ifextraisNone:extra=dict()bestModel=self.bestModel.copy(extra)validationMetrics=list(self.validationMetrics)assertself.subModelsisnotNonesubModels=[model.copy()formodelinself.subModels]returnself._copyValues(TrainValidationSplitModel(bestModel,validationMetrics,subModels),extra=extra)
[docs]@since("2.3.0")defwrite(self)->MLWriter:"""Returns an MLWriter instance for this ML instance."""if_ValidatorSharedReadWrite.is_java_convertible(self):returnJavaMLWriter(self)# type: ignore[arg-type]returnTrainValidationSplitModelWriter(self)
[docs]@classmethod@since("2.3.0")defread(cls)->TrainValidationSplitModelReader:"""Returns an MLReader instance for this class."""returnTrainValidationSplitModelReader(cls)
    @classmethod
    def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplitModel":
        """
        Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
        Used for ML persistence.
        """
        # Load information from java_stage to the instance.
        sc = SparkContext._active_spark_context
        assert sc is not None
        bestModel: Model = JavaParams._from_java(java_stage.bestModel())
        validationMetrics = _java2py(sc, java_stage.validationMetrics())
        estimator, epms, evaluator = super(TrainValidationSplitModel, cls)._from_java_impl(
            java_stage
        )
        # Create a new instance of this stage.
        py_stage = cls(bestModel=bestModel, validationMetrics=validationMetrics)
        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "trainRatio": java_stage.getTrainRatio(),
            "seed": java_stage.getSeed(),
        }
        # Set params one at a time via _set so normal Param conversion applies.
        for param_name, param_val in params.items():
            py_stage = py_stage._set(**{param_name: param_val})
        if java_stage.hasSubModels():
            # Flat layout: one sub-model per param map (no folds, unlike CrossValidator).
            py_stage.subModels = [
                JavaParams._from_java(sub_model) for sub_model in java_stage.subModels()
            ]
        py_stage._resetUid(java_stage.uid())
        return py_stage

    def _to_java(self) -> "JavaObject":
        """
        Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.

        Returns
        -------
        py4j.java_gateway.JavaObject
            Java object equivalent to this instance.
        """
        sc = SparkContext._active_spark_context
        assert sc is not None
        # The Java constructor takes (uid, bestModel, validationMetrics) directly.
        _java_obj = JavaParams._new_java_obj(
            "org.apache.spark.ml.tuning.TrainValidationSplitModel",
            self.uid,
            cast(JavaParams, self.bestModel)._to_java(),
            _py2java(sc, self.validationMetrics),
        )
        estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
        params = {
            "evaluator": evaluator,
            "estimator": estimator,
            "estimatorParamMaps": epms,
            "trainRatio": self.getTrainRatio(),
            "seed": self.getSeed(),
        }
        # Build a Java ParamPair for each param and apply it to the Java stage.
        for param_name, param_val in params.items():
            java_param = _java_obj.getParam(param_name)
            pair = java_param.w(param_val)
            _java_obj.set(pair)
        if self.subModels is not None:
            java_sub_models = [
                cast(JavaParams, sub_model)._to_java() for sub_model in self.subModels
            ]
            _java_obj.setSubModels(java_sub_models)
        return _java_obj
if __name__ == "__main__":
    import doctest

    from pyspark.sql import SparkSession

    # Run this module's doctests against a small local Spark session.
    test_globals = globals().copy()
    session = SparkSession.builder.master("local[2]").appName("ml.tuning tests").getOrCreate()
    test_globals["sc"] = session.sparkContext
    test_globals["spark"] = session
    failure_count, test_count = doctest.testmod(
        globs=test_globals, optionflags=doctest.ELLIPSIS
    )
    session.stop()
    if failure_count:
        sys.exit(-1)