"""Terminal client for a ServerJar socket server.

The client shows the server's broadcast log in a full-screen prompt_toolkit
UI, sends typed lines to the server, and handles a handful of local commands
(those starting with ``_``, e.g. ``_c host:port``, ``_d``, ``_help``).
"""

import argparse
import asyncio
import base64
import binascii
import queue
import re
import shutil
import socket
import ssl
import sys
import threading
import time
import traceback
from pathlib import Path

from prompt_toolkit import Application
from prompt_toolkit.filters import has_focus
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.layout import Layout, HSplit
from prompt_toolkit.lexers import Lexer
from prompt_toolkit.shortcuts import clear as ptk_clear
from prompt_toolkit.styles import Style
from prompt_toolkit.widgets import TextArea

version = "Beta-1"

SERVER_JAR_DIR = Path.home() / ".serverjar"
CLIENT_CERT_DIR = SERVER_JAR_DIR / "client" / "cert"
CLIENT_CERT_SUFFIXES = {".pem", ".crt", ".cer"}
CLIENT_HISTORY_FILE = SERVER_JAR_DIR / "history" / "history.txt"
CLIENT_LOG_DIR = SERVER_JAR_DIR / "client" / "logs"


def add_client_cert(cert_path):
    """Copy a certificate file into the client's certificate directory.

    Args:
        cert_path: Path (str or Path) to the certificate; ``~`` is expanded.

    Returns:
        The Path of the certificate inside ``CLIENT_CERT_DIR``.

    Raises:
        FileNotFoundError: The source path does not exist.
        IsADirectoryError: The source path exists but is not a regular file.
        ValueError: The file suffix is not one of ``CLIENT_CERT_SUFFIXES``.
    """
    source = Path(cert_path).expanduser()
    if not source.exists():
        raise FileNotFoundError(f"{source} does not exist")
    if not source.is_file():
        raise IsADirectoryError(f"{source} is not a file")
    if source.suffix.lower() not in CLIENT_CERT_SUFFIXES:
        allowed = ", ".join(sorted(CLIENT_CERT_SUFFIXES))
        raise ValueError(f"Unsupported certificate suffix '{source.suffix}'. Allowed: {allowed}")

    CLIENT_CERT_DIR.mkdir(parents=True, exist_ok=True)
    target = CLIENT_CERT_DIR / source.name
    # The file is already in place; copying it onto itself would fail.
    if source.resolve() == target.resolve():
        return target
    shutil.copy2(source, target)
    return target


def run_cli_action(argv=None):
    """Handle the one-shot ``--add-cert <path>`` command-line action.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        True when ``--add-cert`` was the first argument and the certificate
        was added; False when there is no CLI action to run. Exits the
        process with status 1 when the certificate cannot be added.
    """
    argv = list(sys.argv[1:] if argv is None else argv)
    if not argv or argv[0] != "--add-cert":
        return False

    parser = argparse.ArgumentParser(prog=f"{Path(sys.argv[0]).name} --add-cert")
    parser.add_argument("cert_path", help="Path to a PEM/CRT/CER certificate file")
    args = parser.parse_args(argv[1:])
    try:
        target = add_client_cert(args.cert_path)
    except (OSError, ValueError) as e:
        # add_client_cert raises OSError subclasses or ValueError only.
        print(f"Unable to add certificate: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Certificate added: {target}")
    return True


def create_client_tls_context(log=None, warn=None):
    """Build a TLS context that also trusts certificates in ``CLIENT_CERT_DIR``.

    Args:
        log: Optional callable receiving informational messages.
        warn: Optional callable receiving warnings about unusable certificates.

    Returns:
        An ``ssl.SSLContext`` created by ``ssl.create_default_context()`` with
        every loadable certificate from the client directory added.
    """
    CLIENT_CERT_DIR.mkdir(parents=True, exist_ok=True)
    context = ssl.create_default_context()
    loaded_certs = []
    for cert_path in sorted(CLIENT_CERT_DIR.iterdir()):
        if not cert_path.is_file() or cert_path.suffix.lower() not in CLIENT_CERT_SUFFIXES:
            continue
        try:
            context.load_verify_locations(cafile=cert_path)
        except OSError as e:
            # ssl.SSLError is an OSError subclass; catching OSError also
            # covers unreadable files, so one bad certificate is skipped
            # instead of aborting the whole connection attempt.
            if callable(warn):
                warn(f"Unable to load TLS certificate {cert_path}: {e}")
            continue
        loaded_certs.append(cert_path.name)

    if callable(log):
        if loaded_certs:
            log("Loaded TLS certificate(s): {}".format(", ".join(loaded_certs)))
        else:
            log(f"No custom TLS certificates found in {CLIENT_CERT_DIR}")
    return context


def get_history(log=None, warn=None):
    """Read the saved command history.

    Args:
        log: Optional callable receiving informational messages.
        warn: Optional callable receiving a warning when reading fails.

    Returns:
        The non-empty lines of ``CLIENT_HISTORY_FILE`` in file order, or an
        empty list when the file is missing or unreadable.
    """
    if not CLIENT_HISTORY_FILE.exists():
        return []
    if callable(log):
        log("Restoring command history...")
    try:
        with CLIENT_HISTORY_FILE.open("r", encoding="utf-8") as f:
            return [line for line in f.read().splitlines() if line]
    except (OSError, UnicodeError) as e:
        if callable(warn):
            warn(f"Unable to restore command history: {e}")
        return []


def save_history(history, log=None, warn=None):
    """Write the command history to ``CLIENT_HISTORY_FILE``, one entry per line.

    Args:
        history: Iterable of command strings, written in the given order.
        log: Optional callable receiving informational messages.
        warn: Optional callable receiving a warning when writing fails.
    """
    if callable(log):
        log("Saving command history...")
    try:
        CLIENT_HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
        with CLIENT_HISTORY_FILE.open("w", encoding="utf-8") as f:
            f.write("\n".join(history))
    except (OSError, UnicodeError) as e:
        if callable(warn):
            warn(f"Unable to save command history: {e}")


class ServerJarClient(Application):
    """Full-screen client UI plus the background socket thread that feeds it."""

    def __init__(self, **kwargs):
        """Build the UI widgets, key bindings and connection state.

        Args:
            **kwargs: Forwarded to ``Application.__init__``.
        """
        super().__init__(**kwargs, mouse_support=True)

        # Text style
        self.style = Style.from_dict({
            "input": "bg:#222222 #ffffff",
            "separator-area": "bg:#000000 #ffffff",
            "message-area": "bg:#111111 #ffffff",
            "log": "bg:#000000 #ffffff",
            "warning": "bg:#000000 ansiyellow",
            "error": "bg:#000000 ansibrightred bold",
            "system": "bg:#000000 ansicyan",
            "process-log": "bg:#000000 ansigreen",
            "process-error": "bg:#000000 ansired",
            "unknown": "bg:#000000 ansiwhite bold",
        })

        class LogLexer(Lexer):
            """Color each log line according to its leading ``[tag]``."""

            tag_pattern = re.compile(r"^\[([A-Za-z0-9_. -]+)([:|][A-Za-z0-9_. -]+)?\]")
            line_style = {
                "auth_err": "class:error",
                "auth_ok": "class:process-log",
                "auth_required": "class:warning",
                "client": "class:log",
                "client|err": "class:error",
                "client|warn": "class:warning",
                "download_log_begin": "class:system",
                "download_log_end": "class:system",
                "err": "class:process-error",
                "log": "class:process-log",
                "ok": "class:process-log",
                "sys": "class:system",
                "sys:err": "class:error",
                "unknown": "class:unknown",
            }

            def lex_document(self, document):
                lines = document.lines

                def get_line(lineno):
                    try:
                        line = lines[lineno]
                    except IndexError:
                        # The renderer may ask for a line past the end.
                        return []
                    match = self.tag_pattern.match(line)
                    tag = "unknown"
                    if match:
                        tag = (match.group(1) + (match.group(2) or "")).lower()
                    style = self.line_style.get(tag, self.line_style["unknown"])
                    return [(style, line)]

                return get_line

        # Socket
        self.sock = None

        # Args
        self.args = None

        # event
        self.kb = KeyBindings()

        # Areas
        # self.log_lines = []
        # self.log_control = FormattedTextControl(
        #     text=lambda: ANSI("".join(self.log_lines))
        # )
        self.log_area = TextArea(
            style="class:log",
            wrap_lines=True,
            lexer=LogLexer(),
        )
        self.separator_area = TextArea(
            text="=" * 10 + " Enter Command Here " + "=" * 10,
            height=1,
            style="class:separator-area",
        )
        self.message_area = TextArea(height=1, multiline=False, style="class:message-area")
        self.input_area = TextArea(height=1, prompt="> ", style="class:input", multiline=False)

        self.layout = Layout(HSplit([
            self.log_area,
            self.separator_area,
            self.message_area,
            self.input_area,
        ]))

        # Thread
        self.sock_lock = threading.Lock()
        self.closing_event = threading.Event()
        self.client_thread = threading.Thread(target=self.client, daemon=True)
        self.connect_event = threading.Event()
        self.disconnect_event = threading.Event()
        self.auth_required = False

        # Queue
        self.incoming = queue.Queue()

        # Reference to the log-consumer task (set in startup) so it is not
        # garbage-collected while running.
        self._consumer_task = None

        # Command History (newest first; current_index is None while the
        # user is editing a fresh line rather than browsing history)
        self.cmds = []
        self.current_index = None
        self.history_draft = ""

        @self.kb.add("c-c")
        def closing_kb(event):
            self.shutdown("Ctrl-C (Stopped by user)")

        @self.kb.add("up", filter=has_focus(self.input_area))
        def get_old_cmd(event):
            # Nothing to browse when the history is empty.
            if not self.cmds:
                return

            # Save the current input so Down can restore it after browsing history.
            if self.current_index is None:
                self.history_draft = self.get_input_area_output()
                self.current_index = 0
            elif self.current_index < len(self.cmds) - 1:
                self.current_index += 1

            self.set_input_area_text(self.cmds[self.current_index])

        @self.kb.add("down", filter=has_focus(self.input_area))
        def get_new_cmd(event):
            # Down only does something while browsing history. This must also
            # work with a single history entry, otherwise the draft could
            # never be restored after pressing Up once.
            if not self.cmds or self.current_index is None:
                return

            if self.current_index == 0:
                self.current_index = None
                self.set_input_area_text(self.history_draft)
                return

            self.current_index -= 1
            self.set_input_area_text(self.cmds[self.current_index])

        @self.kb.add("enter", filter=has_focus(self.input_area))
        def enter_kb(event):
            cmd = self.input_area.text
            self.input_area.text = ""
            self.current_index = None
            self.history_draft = ""

            if cmd == "_exit":
                self.shutdown("_exit command detected")
                return

            if self.auth_required:
                # While the server waits for a password, everything except
                # _d is sent as the password and is never stored in history.
                if cmd == "_d":
                    self.command_parser(cmd)
                    return
                if self._send_line("__auth " + cmd):
                    self.display_message("Password sent, waiting for server...")
                return

            # Save command
            self.insert_new_cmd_to_history(cmd)

            # Local commands are handled here; everything else goes to the server.
            if self.command_parser(cmd):
                return
            self._send_line(cmd)

        @self.kb.add("c-w", filter=has_focus(self.input_area))
        def focus_log_area(event):
            self.layout.focus(self.log_area)
            self.display_message("Now focus at log area.")

        @self.kb.add("c-w", filter=has_focus(self.log_area))
        def focus_input_area(event):
            self.layout.focus(self.input_area)
            self.display_message("Now focus at input area.")

        self.key_bindings = self.kb
        self.full_screen = True

        # Host and Port
        self.host = None
        self.port = None

    def _send_line(self, text):
        """Send ``text`` plus a newline to the server as UTF-8.

        Returns:
            True when the data was handed to the socket, False when there is
            no connection or the send failed (the error is logged).
        """
        with self.sock_lock:
            s = self.sock
        if not s:
            self._err("The remote server is not connected yet.")
            return False
        try:
            s.sendall((text + "\n").encode("utf-8"))
        except OSError as e:
            self._err(f"Send failed: {e}\n")
            return False
        return True

    def command_parser(self, command):
        """Run ``command`` if it is a local (``_``-prefixed) client command.

        Args:
            command: The raw line typed by the user.

        Returns:
            True when the line was handled locally (including blank lines)
            and must not be sent to the server; False otherwise.
        """
        def connect_to_server(host, port):
            # update target
            self.host = host
            self.port = port

            # trigger connection
            self.disconnect_event.clear()
            self.connect_event.set()
            return True

        def disconnect_from_server(cmd):
            self._log("Disconnecting...")
            # Remember whether a (re)connect was pending BEFORE clearing the
            # event; testing it after the clear could never be true.
            reconnect_pending = self.connect_event.is_set()
            self.disconnect_event.set()
            self.connect_event.clear()

            with self.sock_lock:
                s = self.sock
            if s:
                try:
                    s.shutdown(socket.SHUT_RDWR)
                except Exception as e:
                    self._err("An error occurred while shutting down sock: " + str(e))
                try:
                    s.close()
                except Exception as e:
                    self._err("An error occurred while closing the socket: {}".format(e))
            elif reconnect_pending:
                self.host = None
                self.port = None
                self._log("Auto-reconnect stopped.")
            else:
                self._err("The remote server is not connected yet.")
            return True

        def connect_to_server_parser(cmd):
            # Split on whitespace instead of slicing a fixed offset so that
            # leading spaces before "_c" do not corrupt the target.
            parts = cmd.split(None, 1)
            target = parts[1].strip() if len(parts) > 1 else ""
            try:
                host, port_str = target.split(":", 1)
                host = host.strip()
                port = int(port_str.strip())
                if not host:
                    raise ValueError("empty host")
            except ValueError:
                self._err("Usage: _c host:port")
                return True

            connect_to_server(host, port)
            return True

        def _shutdown(cmd):
            self.shutdown("_exit command detected")
            return True

        def _top(cmd):
            self.log_area.buffer.cursor_position = 0
            self.invalidate()
            return True

        def _bottom(cmd):
            self.log_area.buffer.cursor_position = len(self.log_area.buffer.text)
            self.invalidate()
            return True

        def _version(cmd):
            self._log("ServerJar Client Version {}".format(version))
            return True

        def _help(cmd):
            for key, value in cmd_map.items():
                self._log(f"{key}: {value.get('description')}")
            return True

        def _clear(cmd):
            self.log_area.text = ""
            self._log("Log cleared")
            return True

        def _clear_history(cmd):
            self.cmds = []
            self.current_index = None
            self.history_draft = ""
            self._log("History cleared")
            return True

        cmd_map = {
            "_exit": {
                "func": _shutdown,
                "description": "Exit the shell",
            },
            "_c": {
                "func": connect_to_server_parser,
                "description": "Connect to the remote server (Usage: _c host:port)",
            },
            "_d": {
                "func": disconnect_from_server,
                "description": "Disconnect from the remote server",
            },
            "_top": {
                "func": _top,
                "description": "Go to the top of the log area",
            },
            "_bottom": {
                "func": _bottom,
                "description": "Go to the bottom of the log area",
            },
            "_version": {
                "func": _version,
                "description": "Display the version of the client",
            },
            "_clear_history": {
                "func": _clear_history,
                "description": "Clear the command history",
            },
            "_clear": {
                "func": _clear,
                "description": "Clear the log area",
            },
            "_help": {
                "func": _help,
                "description": "Display the help message",
            },
        }

        if not command.strip():
            return True

        header = command.split()[0]
        entry = cmd_map.get(header)
        if entry is None:
            # Not a local command: the caller forwards it to the server.
            return False
        return entry["func"](command)

    def display_message(self, message):
        """Show ``message`` in the one-line message area."""
        self.message_area.text = message

    def get_input_area_output(self):
        """Return the text currently typed in the input area."""
        return self.input_area.text

    def set_input_area_text(self, text):
        """Replace the input text and move the cursor to its end."""
        self.input_area.text = text
        self.input_area.buffer.cursor_position = len(text)

    def insert_new_cmd_to_history(self, cmd):
        """Put ``cmd`` at the front of the history.

        Empty commands and immediate repeats of the newest entry are ignored.
        """
        if not cmd:
            return
        if self.cmds and self.cmds[0] == cmd:
            return
        self.cmds.insert(0, cmd)

    class ServerInfoInvalidException(Exception):
        """Exception carrying a plain message in ``msg``."""

        def __init__(self, message, **kwargs):
            # Pass the message on so ``args`` and ``repr()`` are populated.
            super().__init__(message)
            self.msg = message

        def __str__(self):
            return self.msg

    def arguments_parser(self):
        """Parse the process command line and return the argparse namespace."""
        parser = argparse.ArgumentParser()
        parser.add_argument("-p", "--port", type=int, help="Port number", required=False)
        parser.add_argument('-host', '--host', type=str, help="Hostname", required=False)
        parser.add_argument('-no-tls', '--no-tls', help="Disable TLS and connect over plain TCP",
                            action='store_true', default=False, required=False)
        parser.add_argument('-r', '--retry', help="Retry when disconnect",
                            action='store_true', default=False, required=False)
        parser.add_argument('--add-cert', help="Add server certificate",
                            type=str, default=None, required=False)
        return parser.parse_args()

    def get_tls_context(self):
        """Return the client TLS context, logging through this client."""
        return create_client_tls_context(log=self._log, warn=self._warn)

    def shutdown(self, reason=""):
        """Close the connection, save the history and leave the UI.

        Only the first call has an effect; later calls return immediately.

        Args:
            reason: Text included in the shutdown log line.
        """
        if self.closing_event.is_set():
            return
        self.closing_event.set()
        self._log(f"Shutting down for reason: {reason}")

        with self.sock_lock:
            s = self.sock
        if s:
            try:
                s.close()
            except Exception as e:
                self._err(f"Unable to close socket: {e}")

        # Save history
        save_history(self.cmds, self._log, self._warn)

        # Exit ui event loop
        self.full_exit()

    # @staticmethod
    # def clear_screen():
    #     os.system("cls" if os.name == "nt" else "clear")

    @staticmethod
    def _ensure_log_tag(message):
        """Return ``message`` as a string, prefixed with ``[unknown]`` if untagged."""
        message = f"{message}"
        if re.match(r"^\[[A-Za-z0-9_. -]+(?:[:|][A-Za-z0-9_. -]+)?\]", message):
            return message
        return f"[unknown] {message}"

    def log(self, message):
        """Queue a server line for display, tagging it if it has no tag."""
        self.incoming.put(self._ensure_log_tag(message))

    def _log(self, message):
        """Queue an informational client message."""
        self.incoming.put(f"[client] {message}")
        # self.display_message(f"{message}")

    def _err(self, message):
        """Queue a client error message (rendered with the error style)."""
        self.incoming.put(f"[client|err] {message}")
        # self.display_message(f"ERROR: {message}")

    def _warn(self, message):
        """Queue a client warning message (rendered with the warning style)."""
        self.incoming.put(f"[client|warn] {message}")
        # self.display_message(f"WARNING: {message}")

    def full_exit(self):
        """Stop the prompt_toolkit application and exit the process."""
        self.exit()
        sys.exit()

    async def consume_incoming(self):
        """Move queued messages into the log area until cancelled.

        The blocking ``queue.get`` runs in the default executor so the UI
        event loop stays responsive.
        """
        loop = asyncio.get_running_loop()
        while True:
            msg = await loop.run_in_executor(None, self.incoming.get)

            # self.log_lines.append(msg)
            #
            # if len(self.log_lines) > 2000:
            #     self.log_lines = self.log_lines[-2000:]
            text = self.log_area.text
            text = f"{text}\n{msg}" if text else msg
            # Cap the log so the widget does not grow without bound.
            if len(text) > 300_000:
                text = "New log start here.\n" + text[-250_000:]
            self.log_area.text = text

            # Keep the view scrolled to the newest line.
            self.log_area.buffer.cursor_position = len(self.log_area.buffer.text)
            self.invalidate()

    def client(self):
        """Background thread: connect, read server lines, and reconnect.

        The target comes from the command line for the first attempt only,
        and from ``_c host:port`` afterwards. After a connection ends the
        loop reconnects only with ``--retry``; otherwise it waits for the
        next ``_c``. Server lines are dispatched on their leading tag
        (authentication, log download, plain log output).
        """
        ptk_clear()

        # The command-line target is used once. Reusing it on every pass made
        # the client reconnect forever, ignoring both --retry and _d.
        cli_target_pending = self.args.host is not None and self.args.port is not None

        while not self.closing_event.is_set():
            if self.connect_event.is_set():
                self._log("Reconnecting...")
            elif cli_target_pending:
                self._log(f"Connecting to remote server from {self.args.host}:{self.args.port} (Value from sys.argv)...")
                self.host, self.port = self.args.host, self.args.port
                # Mark the connection as wanted so --retry and _d treat it
                # exactly like one started with _c.
                self.connect_event.set()
            else:
                self._log("Type _c host:port to connect. (_d to disconnect)")
                self.connect_event.wait()
            cli_target_pending = False

            # Snapshot the target: _d may reset self.host/self.port from the
            # UI thread while this attempt is still in progress.
            host, port = self.host, self.port
            if not host or not port:
                self._err("No host/port set. Usage: _c host:port")
                self.connect_event.clear()
                continue

            download_file = None
            try:
                self._log(f"Connecting to {host}:{port} ...")

                # Create connect
                if self.args.no_tls:
                    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    try:
                        s.connect((host, port))
                    except OSError:
                        s.close()
                        raise
                else:
                    raw = socket.create_connection((host, port))
                    try:
                        context = self.get_tls_context()
                        s = context.wrap_socket(raw, server_hostname=host)
                    except Exception:
                        # Do not leak the TCP socket when the TLS setup or
                        # handshake fails.
                        raw.close()
                        raise

                with self.sock_lock:
                    self.sock = s
                    self.auth_required = False

                self._log("Remote socket server connected [HOST: {}, PORT: {}]".format(host, port))

                buffer = ""
                downloading_log = False
                download_path = None

                while True:
                    # Receive remote server broadcast message and display it on log area
                    data = s.recv(4096)
                    if not data:
                        raise ConnectionError("Server closed")

                    buffer += data.decode("utf-8", errors="replace")
                    while "\n" in buffer:
                        line, buffer = buffer.split("\n", 1)

                        if line.startswith("[AUTH_REQUIRED]"):
                            self.auth_required = True
                            self.display_message("Password required. Type the server password to continue.")
                            self._warn("Server requires a password before continuing.")
                            if self.args.no_tls:
                                self._warn("TLS is disabled. The password will be sent without encryption.")
                            continue

                        if line.startswith("[AUTH_OK]"):
                            self.auth_required = False
                            self.display_message("Authenticated.")
                            self._log("Server password accepted.")
                            continue

                        if line.startswith("[AUTH_ERR]"):
                            self.auth_required = True
                            self.display_message("Invalid password. Try again, or use _d to disconnect.")
                            self._err(line)
                            continue

                        if line.startswith("[DOWNLOAD_LOG_BEGIN]"):
                            if download_file:
                                download_file.close()
                            file_name = line[len("[DOWNLOAD_LOG_BEGIN]"):].strip()
                            # Keep only the final path component so the server
                            # cannot choose a location outside the log dir.
                            file_name = Path(file_name).name or "serverjar.log"
                            CLIENT_LOG_DIR.mkdir(parents=True, exist_ok=True)
                            download_path = CLIENT_LOG_DIR / file_name
                            download_file = download_path.open("wb")
                            downloading_log = True
                            self._log(f"Downloading log to {download_path}")
                            continue

                        if line == "[DOWNLOAD_LOG_END]":
                            if download_file:
                                download_file.close()
                                download_file = None
                            downloading_log = False
                            self._log(f"Log downloaded: {download_path}")
                            download_path = None
                            continue

                        if downloading_log:
                            if line.startswith("["):
                                # Tagged lines are messages, not payload; a
                                # server error also aborts the download.
                                if line.startswith("[SYS:ERR]"):
                                    if download_file:
                                        download_file.close()
                                        download_file = None
                                    downloading_log = False
                                    download_path = None
                                self.log(line)
                                continue
                            try:
                                if download_file:
                                    download_file.write(base64.b64decode(line.encode("ascii")))
                            except (binascii.Error, UnicodeError, OSError) as e:
                                # UnicodeError covers non-ASCII payload lines,
                                # which would otherwise drop the connection.
                                self._err(f"Unable to write downloaded log: {e}")
                                if download_file:
                                    download_file.close()
                                    download_file = None
                                downloading_log = False
                                download_path = None
                            continue

                        # ### Use normal log method ###
                        self.log(line)

            except (ConnectionError, OSError) as e:
                if not self.closing_event.is_set():
                    if self.args.retry:
                        self._warn(f"Disconnected: {e}, retrying...")
                    else:
                        self._err(f"Disconnected: {e}")
                    # Pause so a failing target is not hammered in a tight loop.
                    time.sleep(1)
            except KeyboardInterrupt:
                self._log("Exiting...")
                break
            except Exception as e:
                self._err(f"Unhandled exception: {e}")
                self._err(f"{traceback.format_exc()}")
            finally:
                if download_file:
                    download_file.close()
                with self.sock_lock:
                    try:
                        if self.sock:
                            self._log("Closing remote connection (From {}:{})...".format(host, port))
                            self.sock.close()
                    except Exception as e:
                        self._err(f"Unable to close socket: {e}")
                    self.sock = None
                    self.auth_required = False

                # reset flags
                self.disconnect_event.clear()
                if not self.args.retry:
                    self.connect_event.clear()

    def startup(self):
        """Pre-run hook: parse arguments, restore history, start the workers."""
        self.args = self.arguments_parser()
        self.layout.focus(self.input_area)
        self.cmds = get_history(self._log, self._warn)
        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self._consumer_task = asyncio.create_task(self.consume_incoming())
        self.client_thread.start()


if __name__ == "__main__":
    if run_cli_action():
        sys.exit(0)
    app = ServerJarClient()
    app.run(pre_run=app.startup)