Skip to content
This repository has been archived by the owner on Apr 27, 2023. It is now read-only.

Bump tensorflow from 1.13.1 to 2.7.2 in /webapp #375

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions megnet/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def get_flat_data(graphs: list[dict], targets: list | None = None) -> tuple:
return tuple(output)

@staticmethod
def _get_dummy_converter() -> "DummyConverter":
def _get_dummy_converter() -> DummyConverter:
return DummyConverter()

def as_dict(self) -> dict:
Expand All @@ -207,7 +207,7 @@ def as_dict(self) -> dict:
return all_dict

@classmethod
def from_dict(cls, d: dict) -> "StructureGraph":
def from_dict(cls, d: dict) -> StructureGraph:
"""
Initialization from dictionary
Args:
Expand Down Expand Up @@ -258,7 +258,7 @@ def convert(self, structure: Structure, state_attributes: list | None = None) ->
return {"atom": atoms, "bond": bonds, "state": state_attributes, "index1": index1, "index2": index2}

@classmethod
def from_structure_graph(cls, structure_graph: StructureGraph) -> "StructureGraphFixedRadius":
def from_structure_graph(cls, structure_graph: StructureGraph) -> StructureGraphFixedRadius:
"""
Initialize from pymatgen StructureGraph
Args:
Expand Down
7 changes: 4 additions & 3 deletions megnet/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Implements basic GraphModels.
"""
from __future__ import annotations

import os
from warnings import warn

Expand Down Expand Up @@ -73,7 +74,7 @@ def train(
patience: int = 500,
dirname: str = "callback",
**kwargs,
) -> "GraphModel":
) -> GraphModel:
"""
Args:
train_structures: (list) list of pymatgen structures
Expand Down Expand Up @@ -136,7 +137,7 @@ def train_from_graphs(
save_checkpoint: bool = True,
dirname: str = "callback",
**kwargs,
) -> "GraphModel":
) -> GraphModel:
"""
Args:
train_graphs: (list) list of graph dictionaries
Expand Down Expand Up @@ -375,7 +376,7 @@ def save_model(self, filename: str) -> None:
)

@classmethod
def from_file(cls, filename: str) -> "GraphModel":
def from_file(cls, filename: str) -> GraphModel:
"""
Class method to load model from
filename for keras model
Expand Down
1 change: 1 addition & 0 deletions megnet/models/megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Implements megnet models.
"""
from __future__ import annotations

from typing import Callable

import numpy as np
Expand Down
1 change: 1 addition & 0 deletions megnet/utils/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
pretrained megnet model
"""
from __future__ import annotations

import os

import numpy as np
Expand Down
125 changes: 65 additions & 60 deletions webapp/app.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from flask import Flask, request, render_template, make_response
from flask.json import jsonify
from megnet.models import MEGNetModel
from pymatgen import MPRester, Structure
import tensorflow as tf
import glob
import os
import re
from functools import lru_cache

import tensorflow as tf
from flask import Flask, make_response, render_template, request
from flask.json import jsonify
from pymatgen import MPRester, Structure

from megnet.models import MEGNetModel

app = Flask(__name__)
mpr = MPRester(os.environ.get("MAPI_KEY"))
Expand Down Expand Up @@ -36,7 +37,7 @@ def predict(model_name, structure):
model = models[model_name]
with graph.as_default():
return model.predict_structure(structure).ravel()
except Exception as ex:
except Exception:
return float("nan")


Expand All @@ -52,10 +53,18 @@ def get_results(structure):

@lru_cache(maxsize=64)
def get_mp_results(mp_id):
data = mpr.query({"task_id": mp_id},
properties=["structure", "formation_energy_per_atom",
"band_gap", "efermi", "elasticity.K_VRH", "elasticity.G_VRH"],
mp_decode=False)
data = mpr.query(
{"task_id": mp_id},
properties=[
"structure",
"formation_energy_per_atom",
"band_gap",
"efermi",
"elasticity.K_VRH",
"elasticity.G_VRH",
],
mp_decode=False,
)
data = data[0]
s = Structure.from_dict(data["structure"])
formula = s.composition.reduced_formula
Expand All @@ -68,36 +77,32 @@ def get_mp_results(mp_id):
return formula, results


@app.route('/')
@app.route("/")
def index():
return make_response(render_template('index.html'))
return make_response(render_template("index.html"))


@app.route('/models')
@app.route("/models")
def get_models():
return jsonify(list(models.keys()))


@app.route('/query', methods=['GET'])
@app.route("/query", methods=["GET"])
def query():
message = ""
try:
mp_id = request.args.get("mp_id")
formula, results = get_mp_results(mp_id)
except Exception as ex:
except Exception:
message = "Please check your Materials Project Id."
formula = ""
results = []
return make_response(render_template(
'index.html',
mp_id=mp_id,
formula=html_formula(formula),
message=message,
mp_results=results
))
return make_response(
render_template("index.html", mp_id=mp_id, formula=html_formula(formula), message=message, mp_results=results)
)


@app.route('/query_structure', methods=['POST'])
@app.route("/query_structure", methods=["POST"])
def query_structure():
formula = ""
message = ""
Expand All @@ -116,58 +121,50 @@ def query_structure():
results = get_results(s)
except Exception as ex:
message = "Error reading structure! %s" % (str(ex))
return make_response(render_template(
'index.html',
structure_string=structure_string,
formula=html_formula(formula),
structure_results=results,
message=message
))


@app.route('/rest/predict_structure/<string:model_name>', methods=['POST'])
return make_response(
render_template(
"index.html",
structure_string=structure_string,
formula=html_formula(formula),
structure_results=results,
message=message,
)
)


@app.route("/rest/predict_structure/<string:model_name>", methods=["POST"])
def predict_structure_rest(model_name):
try:
structure = Structure.from_str(request.form["structure"],
fmt=request.form["fmt"])
structure = Structure.from_str(request.form["structure"], fmt=request.form["fmt"])
val = predict(model_name, structure)
d = {
"model": model_name,
"val": float(val),
"formula": structure.composition.reduced_formula
}
d = {"model": model_name, "val": float(val), "formula": structure.composition.reduced_formula}
return jsonify(d)
except Exception as ex:
return jsonify(str(ex))


@app.route('/rest/predict_mp/<string:model_name>/<string:mp_id>')
@app.route("/rest/predict_mp/<string:model_name>/<string:mp_id>")
def predict_mp_rest(model_name, mp_id):
try:
structure = mpr.get_structure_by_material_id(mp_id)
val = predict(model_name, structure)
d = {
"model": model_name,
"val": float(val),
"formula": structure.composition.reduced_formula
}
d = {"model": model_name, "val": float(val), "formula": structure.composition.reduced_formula}
return jsonify(d)
except Exception as ex:
return jsonify(str(ex))


@app.route('/predict_structure/<string:model_name>', methods=['POST'])
@app.route("/predict_structure/<string:model_name>", methods=["POST"])
def predict_structure(model_name):
try:
structure = Structure.from_str(request.form["structure"],
fmt=request.form["fmt"])
structure = Structure.from_str(request.form["structure"], fmt=request.form["fmt"])
val = predict(model_name, structure)
return jsonify(float(val))
except Exception as ex:
return jsonify(str(ex))


@app.route('/predict_mp/<string:model_name>/<string:mp_id>')
@app.route("/predict_mp/<string:model_name>/<string:mp_id>")
def predict_mp(model_name, mp_id):
try:
structure = mpr.get_structure_by_material_id(mp_id)
Expand All @@ -182,20 +179,28 @@ def main():
import argparse

parser = argparse.ArgumentParser(
description="""Basic web app for MEGNet.""",
epilog="Authors: Chi Chen, Shyue Ping Ong")
description="""Basic web app for MEGNet.""", epilog="Authors: Chi Chen, Shyue Ping Ong"
)

parser.add_argument("-d", "--debug", dest="debug", action="store_true", help="Whether to run in debug mode.")
parser.add_argument(
"-d", "--debug", dest="debug", action="store_true",
help="Whether to run in debug mode.")
parser.add_argument(
"-hh", "--host", dest="host", type=str, nargs="?",
default='0.0.0.0',
help="Host in which to run the server. Defaults to 0.0.0.0.")
"-hh",
"--host",
dest="host",
type=str,
nargs="?",
default="0.0.0.0",
help="Host in which to run the server. Defaults to 0.0.0.0.",
)
parser.add_argument(
"-p", "--port", dest="port", type=int, nargs="?",
"-p",
"--port",
dest="port",
type=int,
nargs="?",
default=5000,
help="Port in which to run the server. Defaults to 5000.")
help="Port in which to run the server. Defaults to 5000.",
)

args = parser.parse_args()
init()
Expand Down
2 changes: 1 addition & 1 deletion webapp/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pymatgen==2019.7.30
tensorflow==1.13.1
tensorflow==2.7.2
keras==2.2.4
megnet==0.3.5
gunicorn==19.9.0
Expand Down
4 changes: 2 additions & 2 deletions webapp/static/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

#profile{
text-align: center;
font-size: 1.5em;
font-size: 1.5em;
}

input{
Expand Down Expand Up @@ -59,4 +59,4 @@ table {
textarea{
font-size: 0.8em;
margin: auto;
}
}
2 changes: 1 addition & 1 deletion webapp/static/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@ h1, h2 {
.tagline {
font-size: 1.2em;
text-align: center;
}
}