
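RNN Probability Strategy

Overview

The RNN Probability strategy is ported from the MetaTrader expert advisor RNN (barabashkakvn's edition). The original algorithm samples three RSI values spaced one RSI period apart and feeds them into a hand-built probability network that imitates the decision logic of a recurrent neural network. The StockSharp version reproduces this flow through a high-level candle subscription and automatically converts the MetaTrader lot size, pip value and stop-loss/take-profit distances into their StockSharp equivalents.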

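Each time the latest finished candle produces an RSI value, the strategy also looks up the RSI values recorded one and two RSI periods earlier. The three normalized values are combined with eight weights (Weight0 through Weight7) to estimate the probability that the market will fall. That probability is mapped linearly to the range [-1; 1], and the sign of the result decides between going long and going short. As in the original EA, the strategy only ever holds a net position in one direction.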
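Written out, the probability network is a trilinear blend of the eight weights. With p1 the current RSI, p2 the RSI one period back and p3 the RSI two periods back (each divided by 100), and w0 to w7 standing for Weight0 to Weight7, the calculation performed by CalculateProbability in the source reads:

```latex
\begin{aligned}
P &= \frac{1}{100}\Big[(1-p_1)\big((1-p_2)\,((1-p_3)\,w_0 + p_3\,w_1) + p_2\,((1-p_3)\,w_2 + p_3\,w_3)\big) \\
  &\qquad + p_1\big((1-p_2)\,((1-p_3)\,w_4 + p_3\,w_5) + p_2\,((1-p_3)\,w_6 + p_3\,w_7)\big)\Big], \\
s &= 2P - 1 .
\end{aligned}
```

The eight products of p and 1 - p always sum to 1, so P stays in [0; 1] as long as every weight is between 0 and 100, and the signal s stays in [-1; 1]. For example, with all three RSI readings at 50 each product equals 1/8, so with the default weights P = (6 + 96 + 90 + 35 + 64 + 83 + 66 + 50) / 800 = 490 / 800 = 0.6125 and s = 0.225, which leads to a sell.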

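Trading logic

  1. Subscribe to the selected candle type and feed the RelativeStrengthIndex indicator manually, using AppliedPrice (the open price by default) as its input.
  2. Keep finished RSI values in a rolling buffer so that the values one and two periods back are always available.
  3. Normalize the three RSI values to [0; 1] and evaluate the probability network, which blends all eight weights:
    • Weight0, Weight1, Weight2 and Weight3 are scaled by (1 - p1), so they dominate when the current RSI is in the lower half of the range (below 50).
    • Weight4, Weight5, Weight6 and Weight7 are scaled by p1, so they dominate when the current RSI is in the upper half of the range.
  4. Convert the probability into a signal between -1 and +1.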
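  5. If the signal is negative and the strategy is flat or short, buy enough to close any short and end up long TradeVolume; if the signal is zero or positive and the strategy is flat or long, sell enough to close any long and end up short TradeVolume. If the position already points in the signal's direction, no order is sent.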
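  6. Optionally, set a stop loss and a take profit at the same distance in pips. The strategy converts pips into an absolute price distance based on PriceStep, applying the extra ×10 multiplier that MetaTrader uses for 3- and 5-digit quotes; protection is skipped when PriceStep is unknown.
  7. Every evaluation is logged with the RSI inputs, the probability and the resulting signal for later review.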
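The steps above can be condensed into one StockSharp-free C# function that takes the stored RSI history and the current net position. This is a sketch, not the strategy itself: DecideOrder, its tuple result and the sign convention for the order volume (positive buys, negative sells) are hypothetical choices made here, while the arithmetic follows ProcessCandle and CalculateProbability in the source.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical, StockSharp-free sketch of one decision step.
// rsiHistory holds finished RSI values (0..100), oldest first.
// Returns null while fewer than 2 * rsiPeriod + 1 values are stored.
// OrderVolume > 0 means buy, < 0 means sell, 0 means keep the position.
static (decimal Signal, decimal OrderVolume)? DecideOrder(
    IReadOnlyList<decimal> rsiHistory, int rsiPeriod, decimal[] weights,
    decimal position, decimal tradeVolume)
{
    var last = rsiHistory.Count - 1;
    if (rsiPeriod <= 0 || last - 2 * rsiPeriod < 0)
        return null;

    // Current, one-period-old and two-period-old RSI, normalized to [0; 1].
    var p1 = rsiHistory[last] / 100m;
    var p2 = rsiHistory[last - rsiPeriod] / 100m;
    var p3 = rsiHistory[last - 2 * rsiPeriod] / 100m;
    decimal n1 = 1m - p1, n2 = 1m - p2, n3 = 1m - p3;

    // Trilinear blend of Weight0..Weight7, as in CalculateProbability.
    var probability =
        (n1 * (n2 * (n3 * weights[0] + p3 * weights[1]) + p2 * (n3 * weights[2] + p3 * weights[3])) +
         p1 * (n2 * (n3 * weights[4] + p3 * weights[5]) + p2 * (n3 * weights[6] + p3 * weights[7]))) / 100m;
    var signal = probability * 2m - 1m;

    // Negative signal: end up long tradeVolume; otherwise end up short tradeVolume.
    var order = 0m;
    if (signal < 0m && position <= 0m)
        order = Math.Abs(position) + tradeVolume;
    else if (signal >= 0m && position >= 0m)
        order = -(position + tradeVolume);

    return (signal, order);
}
```

Unlike the source, the sketch does not trim the buffer; the strategy keeps at most 2 * RsiPeriod + 5 values (see GetHistoryLimit), which is always enough for the two-period lookback.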

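Parameters

Name | Type | Default | Description
--- | --- | --- | ---
CandleType | DataType | 1-hour time frame | Primary candle series used to generate signals and indicator data.
TradeVolume | decimal | 1 | Lot size of each new position.
RsiPeriod | int | 9 | RSI period (2 to 200); also sets the spacing between the historical samples.
AppliedPrice | AppliedPriceTypes | Open | Price fed into the RSI: Open, High, Low, Close, Median, Typical or Weighted.
StopLossTakeProfitPips | decimal | 100 | Distance in pips for both the stop loss and the take profit; 0 disables the protective orders.
Weight0 to Weight7 | decimal | 6, 96, 90, 35, 64, 83, 66, 50 | The eight probability network weights, each in the range 0 to 100.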

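Differences from the original MetaTrader expert

  • Email notifications were removed; StockSharp logging provides the same feedback.
  • Position size is fixed at TradeVolume; partial closes and gradual scaling in are not implemented, which matches the original code.
  • Indicator data comes from the high-level candle subscription, so there is no need to call CopyBuffer or handle indicator buffers manually.
  • The pip conversion uses the instrument's PriceStep directly and automatically multiplies it by 10 for 3- and 5-digit quotes, instead of hard-coding the pip size.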
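The pip conversion is small enough to check by hand. The sketch below restates CalculatePipSize and GetDecimalPlaces from the source as standalone functions; PipSize and ProtectionDistance are hypothetical names, and the 3/5-digit rule is the one this strategy applies rather than a general StockSharp convention.

```csharp
using System;

// Counts decimal places by multiplying by 10 until the value is whole
// (capped at 10 places), mirroring GetDecimalPlaces in the source.
static int DecimalPlaces(decimal value)
{
    value = Math.Abs(value);
    var places = 0;
    while (value != Math.Truncate(value) && places < 10)
    {
        value *= 10m;
        places++;
    }
    return places;
}

// A 3- or 5-digit quote means the price step is a tenth of a pip.
static decimal PipSize(decimal priceStep)
{
    if (priceStep <= 0m)
        return 0m; // unknown step: the strategy skips protection in this case
    var places = DecimalPlaces(priceStep);
    return places == 3 || places == 5 ? priceStep * 10m : priceStep;
}

// Absolute distance used for both the stop loss and the take profit.
static decimal ProtectionDistance(decimal pips, decimal priceStep) => pips * PipSize(priceStep);
```

For example, a 5-digit quote with PriceStep 0.00001 gives a pip of 0.0001, so the default 100 pips becomes an absolute distance of 0.01.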

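Usage notes

  • Before starting, make sure TradeVolume matches the instrument's minimum volume step; when the strategy starts (in OnStarted2), it also copies this value into Strategy.Volume.
  • Optimizing the eight weights lets the probability network adapt to different market conditions.
  • On instruments with wide spreads, or if you plan to exit positions manually, reduce StopLossTakeProfitPips or set it to 0.
  • Add the strategy to a chart to see candles, RSI and trades together, which makes it easier to verify the neural network's output.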

Indicators

  • One RelativeStrengthIndex calculated on the selected price.
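The strategy treats StockSharp's RelativeStrengthIndex as a black box. To see roughly what feeds p1, p2 and p3, here is a sketch of the classic Wilder RSI (simple-average seed, then smoothing with factor 1/length). Whether StockSharp's implementation, or MetaTrader's iRSI, matches it in every warm-up detail is not confirmed here, so expect possible differences in the first values; WilderRsi is a hypothetical helper, and returning 50 for a completely flat series is a choice of this sketch.

```csharp
using System;
using System.Collections.Generic;

// Classic Wilder RSI: seed the average gain/loss with the simple mean of the
// first `length` price changes, then smooth with (prev * (length - 1) + x) / length.
// Returns one value per price from index `length` onward.
static List<decimal> WilderRsi(IReadOnlyList<decimal> prices, int length)
{
    if (length <= 0)
        throw new ArgumentOutOfRangeException(nameof(length));

    var result = new List<decimal>();
    decimal avgGain = 0m, avgLoss = 0m;

    for (var i = 1; i < prices.Count; i++)
    {
        var change = prices[i] - prices[i - 1];
        var gain = change > 0m ? change : 0m;
        var loss = change < 0m ? -change : 0m;

        if (i <= length)
        {
            // Accumulate the seed sums; divide once the first window is full.
            avgGain += gain;
            avgLoss += loss;
            if (i < length)
                continue;
            avgGain /= length;
            avgLoss /= length;
        }
        else
        {
            avgGain = (avgGain * (length - 1) + gain) / length;
            avgLoss = (avgLoss * (length - 1) + loss) / length;
        }

        // No losses: RSI is 100, or a neutral 50 when prices did not move at all.
        if (avgLoss == 0m)
            result.Add(avgGain == 0m ? 50m : 100m);
        else
            result.Add(100m - 100m / (1m + avgGain / avgLoss));
    }

    return result;
}
```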
namespace StockSharp.Samples.Strategies;

using System;
using System.Linq;
using System.Collections.Generic;

using Ecng.Common;
using Ecng.Collections;
using Ecng.Serialization;

using StockSharp.Algo.Indicators;
using StockSharp.Algo.Strategies;
using StockSharp.BusinessEntities;
using StockSharp.Messages;

using StockSharp.Algo;

/// <summary>
/// Probabilistic strategy converted from the RNN MetaTrader expert.
/// It feeds three delayed RSI readings into the original probability lattice and
/// trades in the direction suggested by the neural network output.
/// </summary>
public class RnnProbabilityStrategy : Strategy
{
	/// <summary>
	/// Price sources that can be forwarded to the RSI indicator.
	/// </summary>
	public enum AppliedPriceTypes
	{
		Open,
		High,
		Low,
		Close,
		Median,
		Typical,
		Weighted
	}

	private readonly StrategyParam<decimal> _tradeVolume;
	private readonly StrategyParam<int> _rsiPeriod;
	private readonly StrategyParam<AppliedPriceTypes> _appliedPrice;
	private readonly StrategyParam<decimal> _stopLossTakeProfitPips;
	private readonly StrategyParam<decimal> _weight0;
	private readonly StrategyParam<decimal> _weight1;
	private readonly StrategyParam<decimal> _weight2;
	private readonly StrategyParam<decimal> _weight3;
	private readonly StrategyParam<decimal> _weight4;
	private readonly StrategyParam<decimal> _weight5;
	private readonly StrategyParam<decimal> _weight6;
	private readonly StrategyParam<decimal> _weight7;
	private readonly StrategyParam<DataType> _candleType;

	private RelativeStrengthIndex _rsi;
	private readonly List<decimal> _rsiHistory = new();
	private decimal _pipSize;

	/// <summary>
	/// Trade volume expressed in lots.
	/// </summary>
	public decimal TradeVolume
	{
		get => _tradeVolume.Value;
		set => _tradeVolume.Value = value;
	}

	/// <summary>
	/// Averaging period for the RSI indicator.
	/// </summary>
	public int RsiPeriod
	{
		get => _rsiPeriod.Value;
		set => _rsiPeriod.Value = value;
	}

	/// <summary>
	/// Price source forwarded to the RSI indicator.
	/// </summary>
	public AppliedPriceTypes AppliedPrice
	{
		get => _appliedPrice.Value;
		set => _appliedPrice.Value = value;
	}

	/// <summary>
	/// Symmetric stop-loss and take-profit distance expressed in pips.
	/// </summary>
	public decimal StopLossTakeProfitPips
	{
		get => _stopLossTakeProfitPips.Value;
		set => _stopLossTakeProfitPips.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (low, low, low) RSI combination.
	/// </summary>
	public decimal Weight0
	{
		get => _weight0.Value;
		set => _weight0.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (low, low, high) RSI combination.
	/// </summary>
	public decimal Weight1
	{
		get => _weight1.Value;
		set => _weight1.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (low, high, low) RSI combination.
	/// </summary>
	public decimal Weight2
	{
		get => _weight2.Value;
		set => _weight2.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (low, high, high) RSI combination.
	/// </summary>
	public decimal Weight3
	{
		get => _weight3.Value;
		set => _weight3.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (high, low, low) RSI combination.
	/// </summary>
	public decimal Weight4
	{
		get => _weight4.Value;
		set => _weight4.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (high, low, high) RSI combination.
	/// </summary>
	public decimal Weight5
	{
		get => _weight5.Value;
		set => _weight5.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (high, high, low) RSI combination.
	/// </summary>
	public decimal Weight6
	{
		get => _weight6.Value;
		set => _weight6.Value = value;
	}

	/// <summary>
	/// Neural network weight for the (high, high, high) RSI combination.
	/// </summary>
	public decimal Weight7
	{
		get => _weight7.Value;
		set => _weight7.Value = value;
	}

	/// <summary>
	/// Candle series used for indicator calculations and trading decisions.
	/// </summary>
	public DataType CandleType
	{
		get => _candleType.Value;
		set => _candleType.Value = value;
	}

	/// <summary>
	/// Initializes a new instance of the <see cref="RnnProbabilityStrategy"/> class.
	/// </summary>
	public RnnProbabilityStrategy()
	{
		_tradeVolume = Param(nameof(TradeVolume), 1m)
			.SetDisplay("Trade Volume", "Lot size used for each market entry.", "General")
			.SetGreaterThanZero();

		_rsiPeriod = Param(nameof(RsiPeriod), 9)
			.SetDisplay("RSI Period", "Length of the RSI indicator feeding the neural network.", "Indicator")
			.SetRange(2, 200);

		_appliedPrice = Param(nameof(AppliedPrice), AppliedPriceTypes.Open)
			.SetDisplay("Applied Price", "Price type forwarded to the RSI indicator.", "Indicator");

		_stopLossTakeProfitPips = Param(nameof(StopLossTakeProfitPips), 100m)
			.SetDisplay("Stop Loss & Take Profit (pips)", "Distance used for both stop-loss and take-profit levels.", "Risk")
			.SetRange(0m, 1000m);

		_weight0 = Param(nameof(Weight0), 6m)
			.SetDisplay("Weight 0", "Probability weight applied when all RSI inputs are low.", "Model")
			.SetRange(0m, 100m);

		_weight1 = Param(nameof(Weight1), 96m)
			.SetDisplay("Weight 1", "Probability weight for the (low, low, high) branch.", "Model")
			.SetRange(0m, 100m);

		_weight2 = Param(nameof(Weight2), 90m)
			.SetDisplay("Weight 2", "Probability weight for the (low, high, low) branch.", "Model")
			.SetRange(0m, 100m);

		_weight3 = Param(nameof(Weight3), 35m)
			.SetDisplay("Weight 3", "Probability weight for the (low, high, high) branch.", "Model")
			.SetRange(0m, 100m);

		_weight4 = Param(nameof(Weight4), 64m)
			.SetDisplay("Weight 4", "Probability weight for the (high, low, low) branch.", "Model")
			.SetRange(0m, 100m);

		_weight5 = Param(nameof(Weight5), 83m)
			.SetDisplay("Weight 5", "Probability weight for the (high, low, high) branch.", "Model")
			.SetRange(0m, 100m);

		_weight6 = Param(nameof(Weight6), 66m)
			.SetDisplay("Weight 6", "Probability weight for the (high, high, low) branch.", "Model")
			.SetRange(0m, 100m);

		_weight7 = Param(nameof(Weight7), 50m)
			.SetDisplay("Weight 7", "Probability weight for the (high, high, high) branch.", "Model")
			.SetRange(0m, 100m);

		_candleType = Param(nameof(CandleType), TimeSpan.FromHours(1).TimeFrame())
			.SetDisplay("Candle Type", "Primary timeframe used for signal generation.", "General");
	}

	/// <inheritdoc />
	public override IEnumerable<(Security sec, DataType dt)> GetWorkingSecurities()
	{
		return [(Security, CandleType)];
	}

	/// <inheritdoc />
	protected override void OnReseted()
	{
		base.OnReseted();

		_rsi = default;
		_rsiHistory.Clear();
		_pipSize = 0m;
	}

	/// <inheritdoc />
	protected override void OnStarted2(DateTime time)
	{
		base.OnStarted2(time);

		Volume = TradeVolume;

		_pipSize = CalculatePipSize();

		Unit stopLossUnit = null;
		Unit takeProfitUnit = null;

		if (StopLossTakeProfitPips > 0m && _pipSize > 0m)
		{
			var distance = StopLossTakeProfitPips * _pipSize;
			stopLossUnit = new Unit(distance, UnitTypes.Absolute);
			takeProfitUnit = new Unit(distance, UnitTypes.Absolute);
		}

		if (stopLossUnit != null || takeProfitUnit != null)
		{
			StartProtection(
				takeProfit: takeProfitUnit,
				stopLoss: stopLossUnit,
				isStopTrailing: false,
				useMarketOrders: true);
		}

		_rsi = new RelativeStrengthIndex
		{
			Length = RsiPeriod
		};

		var subscription = SubscribeCandles(CandleType);
		subscription.Bind(ProcessCandle).Start();

		var area = CreateChartArea();
		if (area != null)
		{
			DrawCandles(area, subscription);
			DrawIndicator(area, _rsi);
			DrawOwnTrades(area);
		}
	}

	private void ProcessCandle(ICandleMessage candle)
	{
		if (candle.State != CandleStates.Finished)
			return;

		if (_rsi == null)
			return;

		if (RsiPeriod <= 0)
			return;

		var price = AppliedPrice switch
		{
			AppliedPriceTypes.Open => candle.OpenPrice,
			AppliedPriceTypes.High => candle.HighPrice,
			AppliedPriceTypes.Low => candle.LowPrice,
			AppliedPriceTypes.Close => candle.ClosePrice,
			AppliedPriceTypes.Median => (candle.HighPrice + candle.LowPrice) / 2m,
			AppliedPriceTypes.Typical => (candle.HighPrice + candle.LowPrice + candle.ClosePrice) / 3m,
			AppliedPriceTypes.Weighted => (candle.HighPrice + candle.LowPrice + 2m * candle.ClosePrice) / 4m,
			_ => candle.ClosePrice,
		};
		var rsiIndicatorValue = _rsi.Process(new DecimalIndicatorValue(_rsi, price, candle.OpenTime) { IsFinal = true });

		if (!_rsi.IsFormed || rsiIndicatorValue.IsEmpty)
			return;

		var rsiValue = rsiIndicatorValue.ToDecimal();

		_rsiHistory.Add(rsiValue);
		TrimHistory(_rsiHistory, GetHistoryLimit());

		var lastIndex = _rsiHistory.Count - 1;
		var delayedIndex = lastIndex - RsiPeriod;
		var delayedTwiceIndex = lastIndex - (2 * RsiPeriod);

		if (delayedIndex < 0 || delayedTwiceIndex < 0)
			return;

		var p1 = _rsiHistory[lastIndex] / 100m;
		var p2 = _rsiHistory[delayedIndex] / 100m;
		var p3 = _rsiHistory[delayedTwiceIndex] / 100m;

		var probability = CalculateProbability(p1, p2, p3);
		var signal = probability * 2m - 1m;

		LogInfo($"RSI inputs: p1={p1:F4}, p2={p2:F4}, p3={p3:F4}, probability={probability:F4}, signal={signal:F4}");

		if (TradeVolume <= 0m)
			return;

		if (signal < 0m)
		{
			// Negative signal: go long. When flat or short, buy enough to close any short and open TradeVolume long.
			if (Position <= 0m)
			{
				var vol = Math.Abs(Position) + TradeVolume;
				BuyMarket(vol);
			}
		}
		else
		{
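			// Zero or positive signal: go short. When flat or long, sell enough to close any long and open TradeVolume short.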
			if (Position >= 0m)
			{
				var vol = Position + TradeVolume;
				SellMarket(vol);
			}
		}
	}

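	/// <summary>
	/// Blends the eight weights trilinearly using the normalized RSI values
	/// (current, one period back, two periods back). The result lies in [0; 1]
	/// as long as every weight stays within 0 to 100.
	/// </summary>
	private decimal CalculateProbability(decimal p1, decimal p2, decimal p3)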
	{
		var pn1 = 1m - p1;
		var pn2 = 1m - p2;
		var pn3 = 1m - p3;

		var probability =
			pn1 * (pn2 * (pn3 * Weight0 + p3 * Weight1) +
			        p2 * (pn3 * Weight2 + p3 * Weight3)) +
			p1 * (pn2 * (pn3 * Weight4 + p3 * Weight5) +
			        p2 * (pn3 * Weight6 + p3 * Weight7));

		return probability / 100m;
	}

	private int GetHistoryLimit()
	{
		return Math.Max((2 * RsiPeriod) + 5, RsiPeriod + 1);
	}

	private static void TrimHistory<T>(List<T> source, int maxSize)
	{
		if (maxSize <= 0)
			return;

		if (source.Count <= maxSize)
			return;

		var removeCount = source.Count - maxSize;
		source.RemoveRange(0, removeCount);
	}

	private decimal CalculatePipSize()
	{
		if (Security == null)
			return 0m;

		var step = Security.PriceStep ?? 0m;
		if (step <= 0m)
			return 0m;

		var decimals = GetDecimalPlaces(step);
		if (decimals == 3 || decimals == 5)
			return step * 10m;

		return step;
	}

	private static int GetDecimalPlaces(decimal value)
	{
		value = Math.Abs(value);

		var decimals = 0;

		while (value != Math.Truncate(value) && decimals < 10)
		{
			value *= 10m;
			decimals++;
		}

		return decimals;
	}
}