# Files
# ServerJar/client.py
# T
# 2026-05-18 23:29:58 +08:00
#
# 771 lines
# 26 KiB
# Python
#
import argparse
import asyncio
import base64
import binascii
import queue
import re
import shutil
import ssl
import sys
import threading
import socket
import time
import traceback
from pathlib import Path
from prompt_toolkit import Application
from prompt_toolkit.layout import Layout, HSplit
from prompt_toolkit.lexers import Lexer
from prompt_toolkit.widgets import TextArea
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.styles import Style
from prompt_toolkit.filters import has_focus
from prompt_toolkit.shortcuts import clear as ptk_clear
# Release identifier of this client.
version = "Beta-1"

# Everything the client persists lives below ~/.serverjar.
SERVER_JAR_DIR = Path.home() / ".serverjar"

# Extra certificates trusted for TLS connections, and the suffixes accepted.
CLIENT_CERT_DIR = SERVER_JAR_DIR / "client" / "cert"
CLIENT_CERT_SUFFIXES = {".pem", ".crt", ".cer"}

# Command history (one command per line) and downloaded server logs.
CLIENT_HISTORY_FILE = SERVER_JAR_DIR / "history" / "history.txt"
CLIENT_LOG_DIR = SERVER_JAR_DIR / "client" / "logs"
def add_client_cert(cert_path):
    """Install a certificate file into the client's certificate directory.

    Args:
        cert_path: Path (``~`` is expanded) of the certificate to install.

    Returns:
        The path of the certificate inside ``CLIENT_CERT_DIR``.

    Raises:
        FileNotFoundError: The path does not exist.
        IsADirectoryError: The path exists but is not a regular file.
        ValueError: The file suffix is not in ``CLIENT_CERT_SUFFIXES``.
    """
    cert_file = Path(cert_path).expanduser()
    if not cert_file.exists():
        raise FileNotFoundError(f"{cert_file} does not exist")
    if not cert_file.is_file():
        raise IsADirectoryError(f"{cert_file} is not a file")

    extension = cert_file.suffix
    if extension.lower() not in CLIENT_CERT_SUFFIXES:
        allowed = ", ".join(sorted(CLIENT_CERT_SUFFIXES))
        raise ValueError(f"Unsupported certificate suffix '{extension}'. Allowed: {allowed}")

    CLIENT_CERT_DIR.mkdir(parents=True, exist_ok=True)
    destination = CLIENT_CERT_DIR / cert_file.name
    # A file that already is the installed copy must not be copied onto itself.
    if cert_file.resolve() != destination.resolve():
        shutil.copy2(cert_file, destination)
    return destination
def run_cli_action(argv=None):
    """Handle the one-shot ``--add-cert <path>`` command-line action.

    Args:
        argv: Argument list without the program name; defaults to
            ``sys.argv[1:]``.

    Returns:
        ``False`` when the first argument is not ``--add-cert`` (nothing was
        done), ``True`` after the certificate has been installed. Exits the
        process with status 1 when the certificate cannot be added.
    """
    cli_args = list(sys.argv[1:]) if argv is None else list(argv)
    wants_add_cert = bool(cli_args) and cli_args[0] == "--add-cert"
    if not wants_add_cert:
        return False

    program = Path(sys.argv[0]).name
    parser = argparse.ArgumentParser(prog=f"{program} --add-cert")
    parser.add_argument("cert_path", help="Path to a PEM/CRT/CER certificate file")
    parsed = parser.parse_args(cli_args[1:])

    try:
        installed = add_client_cert(parsed.cert_path)
    except Exception as exc:
        print(f"Unable to add certificate: {exc}", file=sys.stderr)
        sys.exit(1)

    print(f"Certificate added: {installed}")
    return True
def create_client_tls_context(log=None, warn=None):
    """Create a TLS client context that also trusts the user's certificates.

    Starts from ``ssl.create_default_context()`` and additionally loads every
    regular file in ``CLIENT_CERT_DIR`` whose suffix is listed in
    ``CLIENT_CERT_SUFFIXES``. A certificate that cannot be loaded is reported
    through ``warn`` and skipped; it never prevents the context from being
    created.

    Args:
        log: Optional callable receiving informational messages.
        warn: Optional callable receiving warning messages.

    Returns:
        The configured ``ssl.SSLContext``.
    """
    CLIENT_CERT_DIR.mkdir(parents=True, exist_ok=True)
    context = ssl.create_default_context()
    loaded_certs = []
    for cert_path in sorted(CLIENT_CERT_DIR.iterdir()):
        if not cert_path.is_file() or cert_path.suffix.lower() not in CLIENT_CERT_SUFFIXES:
            continue
        try:
            context.load_verify_locations(cafile=cert_path)
        except OSError as e:
            # ssl.SSLError is a subclass of OSError, so this covers malformed
            # certificates as well as files that cannot be read at all
            # (e.g. PermissionError), which previously escaped to the caller.
            if callable(warn):
                warn(f"Unable to load TLS certificate {cert_path}: {e}")
            continue
        loaded_certs.append(cert_path.name)

    if callable(log):
        if loaded_certs:
            log("Loaded TLS certificate(s): {}".format(", ".join(loaded_certs)))
        else:
            log(f"No custom TLS certificates found in {CLIENT_CERT_DIR}")
    return context
def get_history(log=None, warn=None):
    """Load the saved command history.

    Args:
        log: Optional callable receiving informational messages.
        warn: Optional callable receiving warning messages.

    Returns:
        The non-empty lines of ``CLIENT_HISTORY_FILE`` in file order, or an
        empty list when the file is missing or cannot be read.
    """
    history_file = CLIENT_HISTORY_FILE
    if not history_file.exists():
        return []

    history_file.parent.mkdir(parents=True, exist_ok=True)
    if callable(log):
        log("Restoring command history...")

    try:
        content = history_file.read_text(encoding="utf-8")
        return list(filter(None, content.splitlines()))
    except Exception as exc:
        if callable(warn):
            warn(f"Unable to restore command history: {exc}")
        return []
def save_history(history, log=None, warn=None):
    """Write the command history to ``CLIENT_HISTORY_FILE``.

    The entries are stored one per line, in the order given. Failures are
    reported through ``warn`` instead of being raised.

    Args:
        history: Sequence of command strings.
        log: Optional callable receiving informational messages.
        warn: Optional callable receiving warning messages.
    """
    history_dir = CLIENT_HISTORY_FILE.parent
    history_dir.mkdir(parents=True, exist_ok=True)

    if callable(log):
        log("Saving command history...")

    try:
        with CLIENT_HISTORY_FILE.open(mode="w", encoding="utf-8") as handle:
            handle.write("\n".join(history))
    except Exception as exc:
        if not callable(warn):
            return
        warn(f"Unable to save command history: {exc}")
class ServerJarClient(Application):
    """Terminal client built on a prompt_toolkit ``Application``.

    The UI is a log pane, a separator, a one-line status message and a
    command input line. A background thread (``client``) keeps a socket
    connection to the remote server (TLS unless ``--no-tls`` is given) and
    feeds received lines into the log pane through a queue. Input starting
    with one of the built-in ``_`` commands (see ``command_parser``) is
    handled locally; everything else is sent to the server.
    """
def __init__(self, **kwargs):
    """Build the UI widgets, key bindings and connection state.

    Args:
        **kwargs: Forwarded to ``Application.__init__``. ``mouse_support``
            defaults to ``True`` but may be overridden by the caller.
    """
    # setdefault (instead of a hard-coded keyword) avoids a duplicate-keyword
    # TypeError when the caller passes mouse_support explicitly.
    kwargs.setdefault("mouse_support", True)
    super().__init__(**kwargs)

    # Text style
    self.style = Style.from_dict({
        "input": "bg:#222222 #ffffff",
        "separator-area": "bg:#000000 #ffffff",
        "message-area": "bg:#111111 #ffffff",
        "log": "bg:#000000 #ffffff",
        "warning": "bg:#000000 ansiyellow",
        "error": "bg:#000000 ansibrightred bold",
        "system": "bg:#000000 ansicyan",
        "process-log": "bg:#000000 ansigreen",
        "process-error": "bg:#000000 ansired",
        "unknown": "bg:#000000 ansiwhite bold",
    })

    class LogLexer(Lexer):
        """Style each log line by the ``[tag]`` / ``[tag:sub]`` it starts with."""

        tag_pattern = re.compile(r"^\[([A-Za-z0-9_. -]+)([:|][A-Za-z0-9_. -]+)?\]")
        # Lower-cased tag -> style class; unlisted tags use "unknown".
        line_style = {
            "auth_err": "class:error",
            "auth_ok": "class:process-log",
            "auth_required": "class:warning",
            "client": "class:log",
            "client|err": "class:error",
            "client|warn": "class:warning",
            "download_log_begin": "class:system",
            "download_log_end": "class:system",
            "err": "class:process-error",
            "log": "class:process-log",
            "ok": "class:process-log",
            "sys": "class:system",
            "sys:err": "class:error",
            "unknown": "class:unknown",
        }

        def lex_document(self, document):
            def get_line(lineno):
                try:
                    line = document.lines[lineno]
                except IndexError:
                    # A line past the end of the document has no fragments.
                    return []
                match = self.tag_pattern.match(line)
                tag = "unknown"
                if match:
                    tag = (match.group(1) + (match.group(2) or "")).lower()
                style = self.line_style.get(tag, self.line_style["unknown"])
                return [(style, line)]

            return get_line

    # Socket
    self.sock = None
    # Args (assigned in startup())
    self.args = None
    # Host and Port
    self.host = None
    self.port = None
    # event
    self.kb = KeyBindings()

    # Areas
    self.log_area = TextArea(
        style="class:log",
        wrap_lines=True,
        lexer=LogLexer(),
    )
    self.separator_area = TextArea(text="=" * 10 + " Enter Command Here " + "=" * 10, height=1,
                                   style="class:separator-area")
    self.message_area = TextArea(height=1, multiline=False, style="class:message-area")
    self.input_area = TextArea(height=1, prompt="> ", style="class:input", multiline=False)
    self.layout = Layout(HSplit([
        self.log_area,
        self.separator_area,
        self.message_area,
        self.input_area,
    ]))

    # Thread
    self.sock_lock = threading.Lock()
    self.closing_event = threading.Event()
    self.client_thread = threading.Thread(target=self.client, daemon=True)
    self.connect_event = threading.Event()
    self.disconnect_event = threading.Event()
    self.auth_required = False
    # Queue
    self.incoming = queue.Queue()
    # Command History (newest first; current_index is None while not browsing)
    self.cmds = []
    self.current_index = None
    self.history_draft = ""

    def send_line(text):
        """Send ``text`` plus a newline to the server; return True on success."""
        with self.sock_lock:
            s = self.sock
        if not s:
            self._err("The remote server is not connected yet.")
            return False
        try:
            s.sendall((text + "\n").encode("utf-8"))
        except OSError as e:
            self._err(f"Send failed: {e}\n")
            return False
        return True

    @self.kb.add("c-c")
    def closing_kb(event):
        self.shutdown("Ctrl-C (Stopped by user)")

    @self.kb.add("up", filter=has_focus(self.input_area))
    def get_old_cmd(event):
        # Nothing to browse when the history is empty.
        if not self.cmds:
            return
        # Save the current input so Down can restore it after browsing history.
        if self.current_index is None:
            self.history_draft = self.get_input_area_output()
            self.current_index = 0
        elif self.current_index < len(self.cmds) - 1:
            self.current_index += 1
        self.set_input_area_text(self.cmds[self.current_index])

    @self.kb.add("down", filter=has_focus(self.input_area))
    def get_new_cmd(event):
        # Down only does something while browsing. Checking the browsing
        # state (not the history length) lets a single-entry history
        # restore the saved draft as well.
        if self.current_index is None:
            return
        if self.current_index == 0:
            self.current_index = None
            self.set_input_area_text(self.history_draft)
            return
        self.current_index -= 1
        self.set_input_area_text(self.cmds[self.current_index])

    @self.kb.add("enter", filter=has_focus(self.input_area))
    def enter_kb(event):
        cmd = self.input_area.text
        self.input_area.text = ""
        self.current_index = None
        self.history_draft = ""
        if cmd == "_exit":
            self.shutdown("_exit command detected")
            return
        if self.auth_required:
            # While the server waits for a password only "_d" is handled as
            # a command; anything else is sent as the password and is
            # deliberately not stored in the history.
            if cmd == "_d":
                self.command_parser(cmd)
                return
            if send_line("__auth " + cmd):
                self.display_message("Password sent, waiting for server...")
            return
        # Save command
        self.insert_new_cmd_to_history(cmd)
        # Built-in commands are consumed locally; the rest goes to the server.
        if self.command_parser(cmd):
            return
        send_line(cmd)

    @self.kb.add("c-w", filter=has_focus(self.input_area))
    def focus_log_area(event):
        self.layout.focus(self.log_area)
        self.display_message("Now focus at log area.")

    @self.kb.add("c-w", filter=has_focus(self.log_area))
    def focus_input_area(event):
        self.layout.focus(self.input_area)
        self.display_message("Now focus at input area.")

    self.key_bindings = self.kb
    self.full_screen = True
def command_parser(self, command):
    """Run a built-in client command if ``command`` is one.

    Args:
        command: The raw line typed by the user.

    Returns:
        ``True`` when the line was consumed locally (blank line or built-in
        command) and ``False`` when it is not a built-in command, in which
        case the caller forwards it to the server.
    """

    def connect_to_server(host, port):
        # update target
        self.host = host
        self.port = port
        # trigger connection
        self.disconnect_event.clear()
        self.connect_event.set()
        return True

    def disconnect_from_server(cmd):
        self._log("Disconnecting...")
        # Remember whether a (re)connect was pending *before* the flag is
        # cleared; testing it afterwards could never be true.
        was_connecting = self.connect_event.is_set()
        self.disconnect_event.set()
        self.connect_event.clear()
        with self.sock_lock:
            s = self.sock
        if s:
            try:
                s.shutdown(socket.SHUT_RDWR)
            except OSError as e:
                self._err("An error occurred while shutting down sock: " + str(e))
            try:
                s.close()
            except OSError as e:
                self._err("An error occurred while closing the socket: {}".format(e))
        elif was_connecting:
            self.host = None
            self.port = None
            self._log("Auto-reconnect stopped.")
        else:
            self._err("The remote server is not connected yet.")
        return True

    def connect_to_server_parser(cmd):
        # Everything after the command word is the target. Splitting on
        # whitespace (instead of slicing a fixed prefix) keeps lines with
        # leading or repeated whitespace working.
        parts = cmd.split(None, 1)
        target = parts[1].strip() if len(parts) > 1 else ""
        try:
            host, port_str = target.split(":", 1)
            host = host.strip()
            port = int(port_str.strip())
            if not host:
                raise ValueError("empty host")
        except ValueError:
            self._err("Usage: _c host:port")
            return True
        connect_to_server(host, port)
        return True

    def _shutdown(cmd):
        self.shutdown("_exit command detected")
        return True

    def _top(cmd):
        self.log_area.buffer.cursor_position = 0
        self.invalidate()
        return True

    def _bottom(cmd):
        self.log_area.buffer.cursor_position = len(self.log_area.buffer.text)
        self.invalidate()
        return True

    def _version(cmd):
        self._log("ServerJar Client Version {}".format(version))
        return True

    def _help(cmd):
        for key, value in cmd_map.items():
            self._log(f"{key}: {value.get('description')}")
        return True

    def _clear(cmd):
        self.log_area.text = ""
        self._log("Log cleared")
        return True

    def _clear_history(cmd):
        self.cmds = []
        self.current_index = None
        self.history_draft = ""
        self._log("History cleared")
        return True

    # Command word -> handler and the description shown by "_help".
    cmd_map = {
        "_exit": {
            "func": _shutdown,
            "description": "Exit the shell",
        },
        "_c": {
            "func": connect_to_server_parser,
            "description": "Connect to the remote server (Usage: _c host:port)",
        },
        "_d": {
            "func": disconnect_from_server,
            "description": "Disconnect from the remote server",
        },
        "_top": {
            "func": _top,
            "description": "Go to the top of the log area",
        },
        "_bottom": {
            "func": _bottom,
            "description": "Go to the bottom of the log area",
        },
        "_version": {
            "func": _version,
            "description": "Display the version of the client",
        },
        "_clear_history": {
            "func": _clear_history,
            "description": "Clear the command history",
        },
        "_clear": {
            "func": _clear,
            "description": "Clear the log area",
        },
        "_help": {
            "func": _help,
            "description": "Display the help message",
        },
    }

    # A blank line is consumed so nothing empty is sent to the server.
    if not command.strip():
        return True

    header = command.split()[0]
    entry = cmd_map.get(header)
    if entry is None:
        return False
    return entry["func"](command)
def display_message(self, message):
    """Show ``message`` in the one-line status area."""
    status_line = self.message_area
    status_line.text = message
def get_input_area_output(self):
    """Return the text currently present in the input line."""
    typed_text = self.input_area.text
    return typed_text
def set_input_area_text(self, text):
    """Replace the input line with ``text`` and move the cursor to its end."""
    field = self.input_area
    field.text = text
    field.buffer.cursor_position = len(text)
def insert_new_cmd_to_history(self, cmd):
    """Put ``cmd`` at the front of the history (newest first).

    Empty commands and an immediate repeat of the newest entry are ignored.
    """
    if cmd and not (self.cmds and self.cmds[0] == cmd):
        self.cmds.insert(0, cmd)
class ServerInfoInvalidException(Exception):
    """Error carrying a human-readable message in ``msg``.

    NOTE(review): judging by the name this signals invalid server
    information; no code visible in this file raises it - confirm usage.
    """

    def __init__(self, message, **kwargs):
        # Hand the message to Exception so that ``args``, ``repr()`` and
        # pickling carry it too; extra keyword arguments stay accepted
        # (and ignored) for backward compatibility.
        super().__init__(message)
        self.msg = message

    def __str__(self):
        return str(self.msg)
def arguments_parser(self):
    """Parse the command-line options of the interactive client.

    Returns:
        An ``argparse.Namespace`` with ``port`` (int or None), ``host``
        (str or None), ``no_tls`` (bool), ``retry`` (bool) and ``add_cert``
        (str or None), read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--port", type=int, help="Port number", required=False)
    parser.add_argument('-host', '--host', type=str, help="Hostname", required=False)
    # The flag turns TLS *off*; the help text used to claim the opposite.
    parser.add_argument('-no-tls', '--no-tls', help="Disable TLS (connect without encryption)",
                        action='store_true', default=False, required=False)
    parser.add_argument('-r', '--retry', help="Retry when disconnect", action='store_true', default=False,
                        required=False)
    parser.add_argument('--add-cert', help="Add server certificate", type=str, default=None, required=False)
    return parser.parse_args()
def get_tls_context(self):
    """Return a TLS context whose progress is reported in the client log."""
    reporters = {"log": self._log, "warn": self._warn}
    return create_client_tls_context(**reporters)
def shutdown(self, reason=""):
    """Close the connection, save the history and leave the application.

    Only the first call has an effect; later calls return immediately.

    Args:
        reason: Text included in the shutdown log message.
    """
    already_closing = self.closing_event.is_set()
    if already_closing:
        return
    self.closing_event.set()
    self._log(f"Shutting down for reason: {reason}")

    with self.sock_lock:
        active_socket = self.sock
    if active_socket:
        try:
            active_socket.close()
        except Exception as exc:
            self._err(f"Unable to close socket: {exc}")

    # Persist the history before the UI event loop is left.
    save_history(self.cmds, self._log, self._warn)
    self.full_exit()
# @staticmethod
# def clear_screen():
# os.system("cls" if os.name == "nt" else "clear")
@staticmethod
def _ensure_log_tag(message):
    """Return ``message`` as text that starts with a ``[tag]``.

    Text already starting with ``[tag]``, ``[tag:sub]`` or ``[tag|sub]`` is
    returned unchanged; anything else is prefixed with ``[unknown] ``.
    """
    text = format(message)
    is_tagged = re.match(r"^\[[A-Za-z0-9_. -]+(?:[:|][A-Za-z0-9_. -]+)?\]", text) is not None
    return text if is_tagged else f"[unknown] {text}"
def log(self, message):
    """Queue ``message`` for the log pane, adding a tag when it has none."""
    tagged = self._ensure_log_tag(message)
    self.incoming.put(tagged)
def _log(self, message):
    """Queue an informational client message, tagged ``[client]``."""
    self.incoming.put("[client] {}".format(message))
def _err(self, message):
    """Queue a client error message, tagged ``[client|err]``."""
    self.incoming.put("[client|err] {}".format(message))
def _warn(self, message):
    """Queue a client warning message, tagged ``[client|warn]``."""
    self.incoming.put("[client|warn] {}".format(message))
def full_exit(self):
    """Leave the application, then terminate the interpreter."""
    # First ask the application to stop ...
    self.exit()
    # ... then raise SystemExit so the process ends as well.
    sys.exit()
async def consume_incoming(self):
    """Forever append queued log lines to the log pane.

    After every line the pane is trimmed once it exceeds 300,000
    characters, the cursor is moved to the end of the text and a redraw is
    requested.
    """
    event_loop = asyncio.get_running_loop()
    while True:
        # queue.Queue.get blocks, so it is awaited on the default executor.
        line = await event_loop.run_in_executor(None, self.incoming.get)

        pane = self.log_area
        separator = "\n" if pane.text else ""
        pane.text += separator + line

        if len(pane.text) > 300_000:
            pane.text = "New log start here.\n" + pane.text[-250_000:]

        pane.buffer.cursor_position = len(pane.buffer.text)
        self.invalidate()
def client(self):
    """Connection worker: connect, read server lines, dispatch them.

    Runs until ``closing_event`` is set. Each round waits until a connection
    is requested (``connect_event``), connects to ``self.host:self.port``
    (TLS unless ``--no-tls`` was given) and then reads newline-terminated
    lines:

    * ``[AUTH_REQUIRED]`` / ``[AUTH_OK]`` / ``[AUTH_ERR]`` toggle
      ``auth_required`` and update the status line.
    * ``[DOWNLOAD_LOG_BEGIN] <name>`` ... ``[DOWNLOAD_LOG_END]`` stores the
      base64 payload lines in between as ``CLIENT_LOG_DIR/<name>``.
    * Every other line goes to the log pane.

    A host/port pair given on the command line is used for one initial
    connection, exactly like typing ``_c host:port``; reconnecting afterwards
    follows ``--retry`` and ``_d`` like any other connection.
    """
    ptk_clear()
    # Command-line target: consumed once so that it cannot override a later
    # "_d" or force reconnects when --retry was not requested.
    argv_target_pending = self.args.host is not None and self.args.port is not None

    while not self.closing_event.is_set():
        if argv_target_pending:
            argv_target_pending = False
            self._log(f"Connecting to remote server from {self.args.host}:{self.args.port} (Value from sys.argv)...")
            self.host, self.port = self.args.host, self.args.port
            self.connect_event.set()
        elif self.connect_event.is_set():
            self._log(f"Reconnecting...")
        else:
            self._log("Type _c host:port to connect. (_d to disconnect)")
            self.connect_event.wait()

        if not self.host or not self.port:
            self._err("No host/port set. Usage: _c host:port")
            self.connect_event.clear()
            continue

        # Defined before the try block so that cleanup can rely on them.
        s = None
        download_file = None
        try:
            self._log(f"Connecting to {self.host}:{self.port} ...")
            # Create connect
            if self.args.no_tls:
                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                s.connect((self.host, self.port))
            else:
                s = socket.create_connection((self.host, self.port))
                context = self.get_tls_context()
                s = context.wrap_socket(s, server_hostname=self.host)
            with self.sock_lock:
                self.sock = s
            self.auth_required = False
            self._log("Remote socket server connected [HOST: {}, PORT: {}]".format(self.host, self.port))

            pending = b""
            downloading_log = False
            download_path = None
            while True:
                data = s.recv(4096)
                if not data:
                    raise ConnectionError("Server closed")
                # Split on raw bytes and decode complete lines only, so a
                # multi-byte UTF-8 character cut in two by a recv() boundary
                # is not turned into replacement characters.
                pending += data
                *raw_lines, pending = pending.split(b"\n")
                for raw_line in raw_lines:
                    line = raw_line.decode("utf-8", errors="replace")

                    if line.startswith("[AUTH_REQUIRED]"):
                        self.auth_required = True
                        self.display_message("Password required. Type the server password to continue.")
                        self._warn("Server requires a password before continuing.")
                        if self.args.no_tls:
                            self._warn("TLS is disabled. The password will be sent without encryption.")
                        continue
                    if line.startswith("[AUTH_OK]"):
                        self.auth_required = False
                        self.display_message("Authenticated.")
                        self._log("Server password accepted.")
                        continue
                    if line.startswith("[AUTH_ERR]"):
                        self.auth_required = True
                        self.display_message("Invalid password. Try again, or use _d to disconnect.")
                        self._err(line)
                        continue

                    if line.startswith("[DOWNLOAD_LOG_BEGIN]"):
                        if download_file:
                            download_file.close()
                            download_file = None
                        file_name = line[len("[DOWNLOAD_LOG_BEGIN]"):].strip()
                        # Keep only the final path component of the
                        # server-supplied name.
                        file_name = Path(file_name).name or "serverjar.log"
                        download_path = CLIENT_LOG_DIR / file_name
                        downloading_log = True
                        try:
                            CLIENT_LOG_DIR.mkdir(parents=True, exist_ok=True)
                            download_file = download_path.open("wb")
                        except OSError as e:
                            # A local file problem is not a network failure:
                            # report it and drop the payload that follows.
                            self._err(f"Unable to write downloaded log: {e}")
                            continue
                        self._log(f"Downloading log to {download_path}")
                        continue
                    if line == "[DOWNLOAD_LOG_END]":
                        if download_file:
                            download_file.close()
                            download_file = None
                            # Success is only reported when the file was
                            # still being written at the end marker.
                            self._log(f"Log downloaded: {download_path}")
                        downloading_log = False
                        download_path = None
                        continue
                    if downloading_log:
                        if line.startswith("["):
                            if line.startswith("[SYS:ERR]"):
                                if download_file:
                                    download_file.close()
                                    download_file = None
                                downloading_log = False
                                download_path = None
                            self.log(line)
                            continue
                        if download_file:
                            try:
                                download_file.write(base64.b64decode(line.encode("ascii")))
                            except (ValueError, OSError) as e:
                                # ValueError covers binascii.Error (bad
                                # base64) and UnicodeEncodeError (non-ASCII).
                                self._err(f"Unable to write downloaded log: {e}")
                                download_file.close()
                                download_file = None
                        # After a failure the remaining payload lines are
                        # dropped here instead of flooding the log pane.
                        continue

                    # ### Use normal log method ###
                    self.log(line)
        except OSError as e:
            # ConnectionError, ssl.SSLError and socket errors are OSErrors.
            if not self.closing_event.is_set():
                if self.args.retry:
                    self._warn(f"Disconnected: {e}, retrying...")
                else:
                    self._err(f"Disconnected: {e}")
                time.sleep(1)
        except KeyboardInterrupt:
            self._log("Exiting...")
            break
        except Exception as e:
            self._err(f"Unhandled exception: {e}")
            self._err(f"{traceback.format_exc()}")
        finally:
            if download_file:
                download_file.close()
            with self.sock_lock:
                if self.sock:
                    self._log("Closing remote connection (From {}:{})...".format(self.host, self.port))
                self.sock = None
                self.auth_required = False
            # Close the socket even when connecting or the TLS handshake
            # failed before it was published as self.sock.
            if s is not None:
                try:
                    s.close()
                except OSError as e:
                    self._err(f"Unable to close socket: {e}")
            # reset flags
            self.disconnect_event.clear()
            if not self.args.retry:
                self.connect_event.clear()
def startup(self):
    """Pre-run hook: parse arguments, restore history, start the workers.

    Must be called while an asyncio event loop is running, because
    ``asyncio.create_task`` requires one.
    """
    self.args = self.arguments_parser()
    self.layout.focus(self.input_area)
    self.cmds = get_history(self._log, self._warn)
    # The event loop keeps only a weak reference to tasks; hold a strong one
    # so the log consumer cannot be garbage-collected while running.
    self._consumer_task = asyncio.create_task(self.consume_incoming())
    self.client_thread.start()
if __name__ == "__main__":
    # "--add-cert <path>" is a one-shot action: perform it and stop without
    # starting the interactive client.
    action_handled = run_cli_action()
    if action_handled:
        sys.exit(0)

    # Otherwise start the UI; startup() runs as the pre-run hook.
    client_app = ServerJarClient()
    client_app.run(pre_run=client_app.startup)