Last active
December 23, 2023 03:40
-
-
Save Codegass/ee5c3484dfc1ea0196d27651d338ba39 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package com.envestnet.aaacli.core; | |
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.logging.*;
import java.util.stream.Collectors;
/**
 * Runs a bundled Python model from Java.
 *
 * <p>{@link #main} drives the workflow:
 * <ol>
 *   <li>check that {@code python3} is available;</li>
 *   <li>copy scripts from {@code model/python} into a fresh temporary folder;</li>
 *   <li>install {@code model/requirements.txt} with pip;</li>
 *   <li>execute {@code predict.py}.</li>
 * </ol>
 * The {@code model/...} paths are relative to the current working directory. Progress messages
 * and all child-process output are written to {@link #logFile} inside the temporary folder.
 */
public class MLRunner {
    /** Interpreter used for every Python invocation, so pip and the scripts share one installation. */
    private static final String PYTHON = "python3";

    /** Temporary working directory holding the copied scripts, requirements and the log file. */
    private final File tempScriptFolder;

    /** Class-level logger; this instance attaches a file handler that writes to {@link #logFile}. */
    private final Logger logger;

    /** The {@code MLRunner.log} file inside {@link #tempScriptFolder}. */
    public final File logFile;

    /**
     * Creates the temporary working directory and attaches a file log handler to the class logger.
     *
     * @throws IOException if the temporary directory or the log file cannot be created
     */
    public MLRunner() throws IOException {
        tempScriptFolder = Files.createTempDirectory("python_scripts").toFile();
        // File.deleteOnExit only removes an empty directory. The log file stays here, so in
        // practice the folder survives the JVM and the log can be inspected afterwards.
        tempScriptFolder.deleteOnExit();

        logFile = new File(tempScriptFolder, "MLRunner.log");
        logger = Logger.getLogger(MLRunner.class.getName());
        // FileHandler treats its argument as a pattern in which '%' starts a placeholder
        // (%t, %h, %g, %u). Escape literal '%' so the temp path is used verbatim.
        FileHandler fileHandler = new FileHandler(logFile.getAbsolutePath().replace("%", "%%"));
        fileHandler.setFormatter(new SimpleFormatter());
        // NOTE(review): the handler is attached to the shared class-level logger. Constructing
        // several MLRunner instances therefore accumulates handlers, and a message logged by any
        // instance is written to every log file attached so far.
        logger.addHandler(fileHandler);
        logger.setLevel(Level.ALL);
        logger.info("Temporary directory for Python scripts and logs: " + tempScriptFolder.getAbsolutePath());
    }

    /**
     * Verifies that {@code python3} can be launched and logs the version it reports.
     *
     * @throws IOException if {@code python3} cannot be started or exits with a non-zero code
     * @throws InterruptedException if interrupted while waiting for the process
     */
    public void checkPythonVersion() throws IOException, InterruptedException {
        logger.info("Checking Python version...");
        int exitCode = runAndLog(PYTHON, "--version");
        if (exitCode != 0) {
            throw new IOException(PYTHON + " --version exited with code " + exitCode);
        }
    }

    /**
     * Copies {@code model/requirements.txt} into the temporary folder and installs it with pip.
     *
     * @throws IOException if the requirements file cannot be copied, pip cannot be started,
     *     or pip exits with a non-zero code
     * @throws InterruptedException if interrupted while waiting for pip
     */
    public void installRequirements() throws IOException, InterruptedException {
        logger.info("Installing Python requirements from requirements.txt...");
        // Relative to the current working directory.
        Path requirementsPath = Paths.get("model/requirements.txt");
        File requirementsFile = new File(tempScriptFolder, "requirements.txt");
        Files.copy(requirementsPath, requirementsFile.toPath(), StandardCopyOption.REPLACE_EXISTING);

        // "python3 -m pip" installs into the same interpreter that runScript later uses.
        // A bare "pip3" on PATH may belong to a different Python installation.
        int exitCode = runAndLog(PYTHON, "-m", "pip", "install", "-r", requirementsFile.getAbsolutePath());
        if (exitCode != 0) {
            throw new IOException("pip install failed with exit code " + exitCode);
        }
    }

    /**
     * Copies the known Python scripts from {@code model/python} into the temporary folder.
     * A missing script is logged at SEVERE and skipped; it does not abort the extraction.
     *
     * @throws IOException if copying an existing script fails
     */
    public void extractPythonScripts() throws IOException {
        logger.info("Extracting Python scripts...");
        // Relative to the current working directory.
        Path modelPythonPath = Paths.get("model/python");
        String[] scripts = {
            "predict.py",
            // ... other script names ...
        };
        for (String script : scripts) {
            Path scriptPath = modelPythonPath.resolve(script);
            if (!Files.exists(scriptPath)) {
                logger.severe("Script not found: " + scriptPath);
                continue;
            }
            File scriptFile = new File(tempScriptFolder, script);
            scriptFile.getParentFile().mkdirs(); // Ensure the parent directories exist
            Files.copy(scriptPath, scriptFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
            logger.info("Extracted script: " + script);
        }
    }

    /**
     * Runs a previously extracted script with {@code python3}, logging its combined output.
     * The exit code is only logged: at INFO when it is zero, at WARNING otherwise.
     *
     * @param scriptName script file name, resolved against the temporary folder
     * @throws IOException if the interpreter cannot be started or its output cannot be read
     * @throws InterruptedException if interrupted while waiting for the script
     */
    public void runScript(String scriptName) throws IOException, InterruptedException {
        logger.info("Running Python script: " + scriptName);
        File scriptFile = new File(tempScriptFolder, scriptName);
        int exitCode = runAndLog(PYTHON, scriptFile.getAbsolutePath());
        logger.log(exitCode == 0 ? Level.INFO : Level.WARNING, "Script exited with code : " + exitCode);
    }

    /**
     * Starts {@code command} in {@link #tempScriptFolder}, logs every line of its combined
     * stdout/stderr, and waits for it to finish.
     *
     * <p>The output is fully drained before {@code waitFor()}. Otherwise a child that fills the
     * OS pipe buffer would block forever.
     *
     * @return the process exit code
     */
    private int runAndLog(String... command) throws IOException, InterruptedException {
        ProcessBuilder processBuilder = new ProcessBuilder(command);
        processBuilder.directory(tempScriptFolder);
        processBuilder.redirectErrorStream(true);
        // Make Python write UTF-8 to its pipes so the reader below decodes it correctly,
        // independent of the platform default charset.
        processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
        Process process = processBuilder.start();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                logger.info(line);
            }
        }
        return process.waitFor();
    }

    /** Entry point: checks Python, extracts scripts, installs requirements, then runs {@code predict.py}. */
    public static void main(String[] args) {
        try {
            MLRunner executor = new MLRunner();
            System.out.println("Log file location: " + executor.logFile.getAbsolutePath());
            executor.checkPythonVersion();
            executor.extractPythonScripts();
            executor.installRequirements();
            executor.runScript("predict.py");
        } catch (IOException | InterruptedException e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt status for anything running after main.
                Thread.currentThread().interrupt();
            }
            Logger.getLogger(MLRunner.class.getName()).log(Level.SEVERE, "MLRunner failed", e);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment