Skip to content

Commit

Permalink
Merge pull request #746 from I2PC/co_deepCenter
Browse files Browse the repository at this point in the history
Co deep center
  • Loading branch information
albertmena authored Dec 4, 2023
2 parents 067fcb0 + 4202b62 commit 87a1939
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
2 changes: 2 additions & 0 deletions xmipp3/protocols.conf
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Protocols SPA = [
{"tag": "section", "text": "2D", "children": [
{"tag": "protocol_group", "text": "Align", "openItem": "False", "children": [
{"tag": "protocol", "value": "XmippProtCL2DAlign", "text": "default"},
{"tag": "protocol", "value": "XmippProtDeepCenter", "text": "default"},
{"tag": "section", "text": "more", "openItem": "False", "children": [
{"tag": "protocol", "value": "ProtAlignmentAssign", "text": "default"}
]}
Expand Down Expand Up @@ -98,6 +99,7 @@ Protocols SPA = [
{"tag": "protocol_group", "text": "Classify", "openItem": "False", "children": []},
{"tag": "protocol_group", "text": "Refine", "openItem": "False", "children": [
{"tag": "protocol", "value": "XmippProtReconstructHighRes", "text": "default"},
{"tag": "protocol", "value": "XmippProtDeepGlobalAssignment", "text": "default"},
{"tag": "protocol", "value": "XmippProtProjMatch", "text": "default"},
{"tag": "protocol", "value": "XmippProtLocalCTF", "text": "default"},
{"tag": "section", "text": "more", "openItem": "False", "children": []}
Expand Down
2 changes: 1 addition & 1 deletion xmipp3/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from .protocol_ctf_correct_wiener2d import XmippProtCTFCorrectWiener2D
from .protocol_consensus_classes import XmippProtConsensusClasses
from .protocol_denoise_particles import XmippProtDenoiseParticles
from .protocol_deep_micrograph_screen import XmippProtDeepMicrographScreen
from .protocol_deep_micrograph_screen import XmippProtDeepMicrographScreen
from .protocol_eliminate_empty_images import (XmippProtEliminateEmptyParticles,
XmippProtEliminateEmptyClasses)
from .protocol_enrich import XmippProtEnrich
Expand Down
56 changes: 33 additions & 23 deletions xmipp3/protocols/protocol_deep_center_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, **args):

# --------------------------- DEFINE param functions --------------------------------------------
def _defineParams(self, form):
form.addParallelSection(threads=1, mpi=1)
form.addParallelSection(threads=1, mpi=4)

form.addHidden(GPU_LIST, StringParam, default='0',
expertLevel=LEVEL_ADVANCED,
Expand All @@ -65,7 +65,7 @@ def _defineParams(self, form):
help='The set of particles to predict')

form.addParam('trainModels', BooleanParam, label="Train models",
pointerClass='SetOfParticles', default=False,
pointerClass='SetOfParticles', default=True,
help='Choose if you want to train a model using a centered set of particles')

trainingGroup = form.addGroup('Training parameters', condition='trainModels==True')
Expand All @@ -77,11 +77,14 @@ def _defineParams(self, form):
trainingGroup.addParam('inputTrainSet', PointerParam, label="Input train set",
pointerClass='SetOfParticles', allowsNull=True,
pointerCondition='hasAlignment2D or hasAlignmentProj',
help='The set of particles to train the models')
help='The set of particles to train the models. If empty, the input image set is taken.')

trainingGroup.addParam('trainSetSize', IntParam, label="Train set size", default=5000,
help='How many particles from the training')

trainingGroup.addParam('numEpochs', IntParam,
label="Number of epochs",
default=25,
default=10,
expertLevel=LEVEL_ADVANCED,
help="Number of epochs for training.")

Expand Down Expand Up @@ -146,9 +149,12 @@ def insertTrainSteps(self):
self._insertFunctionStep("train", numGPU[0])

# --------------------------- STEPS functions ---------------------------------------------------
def convertStep(self, inputSet):
def convertStep(self, inputSet, train=False):
self.Xdim = 128
writeSetOfParticles(inputSet, self.predictImgsFn)
if train:
self.runJob("xmipp_metadata_utilities","-i %s --operate random_subset %d"%\
(self.predictImgsFn,self.trainSetSize), numberOfMpi=1)
self.runJob("xmipp_image_resize",
"-i %s -o %s --save_metadata_stack %s --fourier %d" %
(self.predictImgsFn,
Expand All @@ -157,36 +163,39 @@ def convertStep(self, inputSet):
self.Xdim), numberOfMpi=self.numberOfThreads.get() * self.numberOfMpi.get())

def convertTrainStep(self):
self.convertStep(self.inputTrainSet.get())
if self.inputTrainSet.get() is None:
self.convertStep(self.inputImageSet.get(), train=True)
else:
self.convertStep(self.inputTrainSet.get(), train=True)

def train(self, gpuId, mode="", orderSymmetry=None):
def train(self, gpuId, mode="", symmetry=None):
args = "%s %s %f %d %d %s %d %f %d %d" % (
self._getExtraPath(f"{self._trainingResizedFileName}.xmd"), self._getExtraPath("model"), self.sigma.get(),
self.numEpochs, self.batchSize.get(), gpuId, self.numModels.get(), self.learningRate.get(),
self.patience.get(), 0)
if orderSymmetry:
args += " " + str(orderSymmetry)
if symmetry:
args += " " + str(symmetry)
self.runJob(f"xmipp_deep_{mode}", args, numberOfMpi=1, env=self.getCondaEnv())

remove(self._getExtraPath(f"{self._trainingResizedFileName}.xmd"))
remove(self._getExtraPath(f"{self._trainingResizedFileName}.stk"))

def predict(self, predictImgsFn, gpuId, mode="", inputModel="", trainedModel=True, orderSymmetry=None):
def predict(self, predictImgsFn, gpuId, mode="", inputModel="", trainedModel=True, symmetry=None):
if mode != "center" or trainedModel:
args = "%s %s %s %s %s %d %d %d" % (
self._getExtraPath(f"{self._trainingResizedFileName}.xmd"), gpuId, self._getPath(), predictImgsFn,
inputModel, self.numModels.get(), self.tolerance.get(),
self.maxModels.get())
if orderSymmetry:
args += " " + str(orderSymmetry)
if symmetry:
args += " " + str(symmetry)
else:
args = "%s %s %s %s %s %d %d %d" % (
self._getExtraPath(f"{self._trainingResizedFileName}.xmd"), gpuId, self._getPath(), predictImgsFn,
os.path.join(XmippScript.getModel("deep_center"), 'modelCenter'), self.numModels.get(), self.tolerance.get(),
self.maxModels.get())
self.runJob(f"xmipp_deep_{mode}_predict", args, numberOfMpi=1, env=self.getCondaEnv())
remove(self._getExtraPath(f"{self._trainingResizedFileName}.xmd"))
remove(self._getExtraPath(f"{self._trainingResizedFileName}.stk"))
#remove(self._getExtraPath(f"{self._trainingResizedFileName}.xmd"))
#remove(self._getExtraPath(f"{self._trainingResizedFileName}.stk"))

def createOutputStep(self):
imgFname = self._getPath('predict_results.xmd')
Expand Down Expand Up @@ -217,10 +226,10 @@ def _insertAllSteps(self):
super()._insertAllSteps()

# --------------------------- STEPS functions ---------------------------------------------------
def train(self, gpuId, mode="", orderSymmetry=None):
def train(self, gpuId, mode="", symmetry=None):
super().train(gpuId, mode="center")

def predict(self, predictImgsFn, gpuId, mode="", inputModel="", trainedModel=True, orderSymmetry=None):
def predict(self, predictImgsFn, gpuId, mode="", inputModel="", trainedModel=True, symmetry=None):
super().predict(gpuId, predictImgsFn, mode="center",
inputModel=self._getExtraPath("model"),
trainedModel=self.trainModels.get())
Expand All @@ -244,18 +253,19 @@ def _defineParams(self, form):
trainingGroup.condition = String('True')

section = form.getSection(label=Message.LABEL_INPUT)
section.addParam('orderSymmetry', IntParam,
label="Order of symmetry", default=1,
help="Order of the group of the molecule.")
section.addParam('symmetry', StringParam,
label="Symmetry", default="c1",
help="Symmetry of the molecule")

# --------------------------- INSERT steps functions --------------------------------------------
def _insertAllSteps(self):
self.insertTrainSteps()
super()._insertAllSteps()

# --------------------------- STEPS functions ---------------------------------------------------
def train(self, gpuId, mode="", orderSymmetry=None):
super().train(gpuId, mode="global_assignment", orderSymmetry=self.orderSymmetry.get())
def train(self, gpuId, mode="", symmetry=None):
super().train(gpuId, mode="global_assignment", symmetry=self.symmetry.get())

def predict(self, predictImgsFn, gpuId, mode="", inputModel="", trainedModel=True, orderSymmetry=None):
super().predict(gpuId, predictImgsFn, mode="global_assignment", inputModel=self._getExtraPath("model"), orderSymmetry=self.orderSymmetry.get())
def predict(self, predictImgsFn, gpuId, mode="", inputModel="", trainedModel=True, symmetry=None):
super().predict(gpuId, predictImgsFn, mode="global_assignment", inputModel=self._getExtraPath("model"),
symmetry=self.symmetry.get())

0 comments on commit 87a1939

Please sign in to comment.